diff --git a/.github/workflows/on_pull_request.yaml b/.github/workflows/on_pull_request.yaml index 78e608ee9..9e44c806a 100644 --- a/.github/workflows/on_pull_request.yaml +++ b/.github/workflows/on_pull_request.yaml @@ -20,7 +20,7 @@ jobs: # not to have GHA download an (at the time of writing) 4 GB cache # of PyTorch and other dependencies. - name: Install Ruff - run: pip install ruff==0.0.272 + run: pip install ruff==0.1.6 - name: Run Ruff run: ruff . lint-js: diff --git a/.github/workflows/run_tests.yaml b/.github/workflows/run_tests.yaml index 3dafaf8dc..f42e4758e 100644 --- a/.github/workflows/run_tests.yaml +++ b/.github/workflows/run_tests.yaml @@ -20,6 +20,12 @@ jobs: cache-dependency-path: | **/requirements*txt launch.py + - name: Cache models + id: cache-models + uses: actions/cache@v3 + with: + path: models + key: "2023-12-30" - name: Install test dependencies run: pip install wait-for-it -r requirements-test.txt env: @@ -33,6 +39,8 @@ jobs: TORCH_INDEX_URL: https://download.pytorch.org/whl/cpu WEBUI_LAUNCH_LIVE_OUTPUT: "1" PYTHONUNBUFFERED: "1" + - name: Print installed packages + run: pip freeze - name: Start test server run: > python -m coverage run @@ -49,7 +57,7 @@ jobs: 2>&1 | tee output.txt & - name: Run tests run: | - wait-for-it --service 127.0.0.1:7860 -t 600 + wait-for-it --service 127.0.0.1:7860 -t 20 python -m pytest -vv --junitxml=test/results.xml --cov . --cov-report=xml --verify-base-url test - name: Kill test server if: always() diff --git a/.gitignore b/.gitignore index 09734267f..6790e9ee7 100644 --- a/.gitignore +++ b/.gitignore @@ -37,3 +37,4 @@ notification.mp3 /node_modules /package-lock.json /.coverage* +/test/test_outputs diff --git a/CHANGELOG.md b/CHANGELOG.md index 1cd3572c8..67429bbff 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,170 @@ +## 1.7.0 + +### Features: +* settings tab rework: add search field, add categories, split UI settings page into many +* add altdiffusion-m18 support ([#13364](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13364)) +* support inference with LyCORIS GLora networks ([#13610](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13610)) +* add lora-embedding bundle system ([#13568](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13568)) +* option to move prompt from top row into generation parameters +* add support for SSD-1B ([#13865](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13865)) +* support inference with OFT networks ([#13692](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13692)) +* script metadata and DAG sorting mechanism ([#13944](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13944)) +* support HyperTile optimization ([#13948](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13948)) +* add support for SD 2.1 Turbo ([#14170](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14170)) +* remove Train->Preprocessing tab and put all its functionality into Extras tab +* initial IPEX support for Intel Arc GPU ([#14171](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14171)) + +### Minor: +* allow reading model hash from images in img2img batch mode ([#12767](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12767)) +* add option to align with sgm repo's sampling implementation ([#12818](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12818)) +* extra field for lora metadata viewer: `ss_output_name` ([#12838](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12838)) +* add action in settings page to calculate all SD checkpoint hashes ([#12909](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12909)) +* add button to copy prompt to style editor ([#12975](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12975)) +* add --skip-load-model-at-start option ([#13253](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13253)) +* write infotext to gif images +* read infotext from gif images ([#13068](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13068)) +* allow configuring the initial state of InputAccordion in ui-config.json ([#13189](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13189)) +* allow editing whitespace delimiters for ctrl+up/ctrl+down prompt editing ([#13444](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13444)) +* prevent accidentally closing popup dialogs ([#13480](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13480)) +* added option to play notification sound or not ([#13631](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13631)) +* show the preview image in the full screen image viewer if available ([#13459](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13459)) +* support for webui.settings.bat ([#13638](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13638)) +* add an option to not print stack traces on ctrl+c +* start/restart generation by Ctrl (Alt) + Enter ([#13644](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13644)) +* update prompts_from_file script to allow concatenating entries with the general prompt ([#13733](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13733)) +* added a visible checkbox to input accordion +* added an option to hide all txt2img/img2img parameters in an accordion ([#13826](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13826)) +* added 'Path' sorting option for Extra network cards ([#13968](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13968)) +* enable prompt hotkeys in style editor ([#13931](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13931)) +* option to show batch img2img results in UI ([#14009](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14009)) +* infotext updates: add option to disregard certain infotext fields, add option to not include VAE in infotext, add explanation to infotext settings page, move some options to infotext settings page +* add FP32 fallback support on sd_vae_approx ([#14046](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14046)) +* support XYZ scripts / split hires path from unet ([#14126](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14126)) +* allow use of mutiple styles csv files ([#14125](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14125)) + +### Extensions and API: +* update gradio to 3.41.2 +* support installed extensions list api ([#12774](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12774)) +* update pnginfo API to return dict with parsed values +* add noisy latent to `ExtraNoiseParams` for callback ([#12856](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12856)) +* show extension datetime in UTC ([#12864](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12864), [#12865](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12865), [#13281](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13281)) +* add an option to choose how to combine hires fix and refiner +* include program version in info response. ([#13135](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13135)) +* sd_unet support for SDXL +* patch DDPM.register_betas so that users can put given_betas in model yaml ([#13276](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13276)) +* xyz_grid: add prepare ([#13266](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13266)) +* allow multiple localization files with same language in extensions ([#13077](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13077)) +* add onEdit function for js and rework token-counter.js to use it +* fix the key error exception when processing override_settings keys ([#13567](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13567)) +* ability for extensions to return custom data via api in response.images ([#13463](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13463)) +* call state.jobnext() before postproces*() ([#13762](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13762)) +* add option to set notification sound volume ([#13884](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13884)) +* update Ruff to 0.1.6 ([#14059](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14059)) +* add Block component creation callback ([#14119](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14119)) +* catch uncaught exception with ui creation scripts ([#14120](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14120)) +* use extension name for determining an extension is installed in the index ([#14063](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14063)) +* update is_installed() from launch_utils.py to fix reinstalling already installed packages ([#14192](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14192)) + +### Bug Fixes: +* fix pix2pix producing bad results +* fix defaults settings page breaking when any of main UI tabs are hidden +* fix error that causes some extra networks to be disabled if both and are present in the prompt +* fix for Reload UI function: if you reload UI on one tab, other opened tabs will no longer stop working +* prevent duplicate resize handler ([#12795](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12795)) +* small typo: vae resolve bug ([#12797](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12797)) +* hide broken image crop tool ([#12792](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12792)) +* don't show hidden samplers in dropdown for XYZ script ([#12780](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12780)) +* fix style editing dialog breaking if it's opened in both img2img and txt2img tabs +* hide --gradio-auth and --api-auth values from /internal/sysinfo report +* add missing infotext for RNG in options ([#12819](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12819)) +* fix notification not playing when built-in webui tab is inactive ([#12834](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12834)) +* honor `--skip-install` for extension installers ([#12832](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12832)) +* don't print blank stdout in extension installers ([#12833](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12833), [#12855](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12855)) +* get progressbar to display correctly in extensions tab +* keep order in list of checkpoints when loading model that doesn't have a checksum +* fix inpainting models in txt2img creating black pictures +* fix generation params regex ([#12876](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12876)) +* fix batch img2img output dir with script ([#12926](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12926)) +* fix #13080 - Hypernetwork/TI preview generation ([#13084](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13084)) +* fix bug with sigma min/max overrides. ([#12995](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12995)) +* more accurate check for enabling cuDNN benchmark on 16XX cards ([#12924](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12924)) +* don't use multicond parser for negative prompt counter ([#13118](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13118)) +* fix data-sort-name containing spaces ([#13412](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13412)) +* update card on correct tab when editing metadata ([#13411](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13411)) +* fix viewing/editing metadata when filename contains an apostrophe ([#13395](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13395)) +* fix: --sd_model in "Prompts from file or textbox" script is not working ([#13302](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13302)) +* better Support for Portable Git ([#13231](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13231)) +* fix issues when webui_dir is not work_dir ([#13210](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13210)) +* fix: lora-bias-backup don't reset cache ([#13178](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13178)) +* account for customizable extra network separators whyen removing extra network text from the prompt ([#12877](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12877)) +* re fix batch img2img output dir with script ([#13170](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13170)) +* fix `--ckpt-dir` path separator and option use `short name` for checkpoint dropdown ([#13139](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13139)) +* consolidated allowed preview formats, Fix extra network `.gif` not woking as preview ([#13121](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13121)) +* fix venv_dir=- environment variable not working as expected on linux ([#13469](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13469)) +* repair unload sd checkpoint button +* edit-attention fixes ([#13533](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13533)) +* fix bug when using --gfpgan-models-path ([#13718](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13718)) +* properly apply sort order for extra network cards when selected from dropdown +* fixes generation restart not working for some users when 'Ctrl+Enter' is pressed ([#13962](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13962)) +* thread safe extra network list_items ([#13014](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13014)) +* fix not able to exit metadata popup when pop up is too big ([#14156](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14156)) +* fix auto focal point crop for opencv >= 4.8 ([#14121](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14121)) +* make 'use-cpu all' actually apply to 'all' ([#14131](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14131)) +* extras tab batch: actually use original filename +* make webui not crash when running with --disable-all-extensions option + +### Other: +* non-local condition ([#12814](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12814)) +* fix minor typos ([#12827](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12827)) +* remove xformers Python version check ([#12842](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12842)) +* style: file-metadata word-break ([#12837](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12837)) +* revert SGM noise multiplier change for img2img because it breaks hires fix +* do not change quicksettings dropdown option when value returned is `None` ([#12854](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12854)) +* [RC 1.6.0 - zoom is partly hidden] Update style.css ([#12839](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12839)) +* chore: change extension time format ([#12851](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12851)) +* WEBUI.SH - Use torch 2.1.0 release candidate for Navi 3 ([#12929](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12929)) +* add Fallback at images.read_info_from_image if exif data was invalid ([#13028](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13028)) +* update cmd arg description ([#12986](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12986)) +* fix: update shared.opts.data when add_option ([#12957](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12957), [#13213](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13213)) +* restore missing tooltips ([#12976](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12976)) +* use default dropdown padding on mobile ([#12880](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12880)) +* put enable console prompts option into settings from commandline args ([#13119](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13119)) +* fix some deprecated types ([#12846](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12846)) +* bump to torchsde==0.2.6 ([#13418](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13418)) +* update dragdrop.js ([#13372](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13372)) +* use orderdict as lru cache:opt/bug ([#13313](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13313)) +* XYZ if not include sub grids do not save sub grid ([#13282](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13282)) +* initialize state.time_start befroe state.job_count ([#13229](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13229)) +* fix fieldname regex ([#13458](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13458)) +* change denoising_strength default to None. ([#13466](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13466)) +* fix regression ([#13475](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13475)) +* fix IndexError ([#13630](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13630)) +* fix: checkpoints_loaded:{checkpoint:state_dict}, model.load_state_dict issue in dict value empty ([#13535](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13535)) +* update bug_report.yml ([#12991](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12991)) +* requirements_versions httpx==0.24.1 ([#13839](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13839)) +* fix parenthesis auto selection ([#13829](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13829)) +* fix #13796 ([#13797](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13797)) +* corrected a typo in `modules/cmd_args.py` ([#13855](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13855)) +* feat: fix randn found element of type float at pos 2 ([#14004](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14004)) +* adds tqdm handler to logging_config.py for progress bar integration ([#13996](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13996)) +* hotfix: call shared.state.end() after postprocessing done ([#13977](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13977)) +* fix dependency address patch 1 ([#13929](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13929)) +* save sysinfo as .json ([#14035](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14035)) +* move exception_records related methods to errors.py ([#14084](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14084)) +* compatibility ([#13936](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13936)) +* json.dump(ensure_ascii=False) ([#14108](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14108)) +* dir buttons start with / so only the correct dir will be shown and no… ([#13957](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13957)) +* alternate implementation for unet forward replacement that does not depend on hijack being applied +* re-add `keyedit_delimiters_whitespace` setting lost as part of commit e294e46 ([#14178](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14178)) +* fix `save_samples` being checked early when saving masked composite ([#14177](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14177)) +* slight optimization for mask and mask_composite ([#14181](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14181)) +* add import_hook hack to work around basicsr/torchvision incompatibility ([#14186](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14186)) + +## 1.6.1 + +### Bug Fixes: + * fix an error causing the webui to fail to start ([#13839](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13839)) + ## 1.6.0 ### Features: diff --git a/README.md b/README.md index c7a4e363f..9f9f33b12 100644 --- a/README.md +++ b/README.md @@ -91,6 +91,7 @@ A browser interface based on Gradio library for Stable Diffusion. - Eased resolution restriction: generated image's dimensions must be a multiple of 8 rather than 64 - Now with a license! - Reorder elements in the UI from settings screen +- [Segmind Stable Diffusion](https://huggingface.co/segmind/SSD-1B) support ## Installation and Running Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for: @@ -120,7 +121,9 @@ Alternatively, use online services (like Google Colab): # Debian-based: sudo apt install wget git python3 python3-venv libgl1 libglib2.0-0 # Red Hat-based: -sudo dnf install wget git python3 +sudo dnf install wget git python3 gperftools-libs libglvnd-glx +# openSUSE-based: +sudo zypper install wget git python3 libtcmalloc4 libglvnd # Arch-based: sudo pacman -S wget git python3 ``` @@ -146,7 +149,7 @@ For the purposes of getting Google and other search engines to crawl the wiki, h ## Credits Licenses for borrowed code can be found in `Settings -> Licenses` screen, and also in `html/licenses.html` file. -- Stable Diffusion - https://github.com/CompVis/stable-diffusion, https://github.com/CompVis/taming-transformers +- Stable Diffusion - https://github.com/Stability-AI/stablediffusion, https://github.com/CompVis/taming-transformers - k-diffusion - https://github.com/crowsonkb/k-diffusion.git - GFPGAN - https://github.com/TencentARC/GFPGAN.git - CodeFormer - https://github.com/sczhou/CodeFormer @@ -173,5 +176,6 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al - TAESD - Ollin Boer Bohan - https://github.com/madebyollin/taesd - LyCORIS - KohakuBlueleaf - Restart sampling - lambertae - https://github.com/Newbeeer/diffusion_restart_sampling +- Hypertile - tfernd - https://github.com/tfernd/HyperTile - Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user. - (You) diff --git a/configs/sd_xl_inpaint.yaml b/configs/sd_xl_inpaint.yaml new file mode 100644 index 000000000..3bad37218 --- /dev/null +++ b/configs/sd_xl_inpaint.yaml @@ -0,0 +1,98 @@ +model: + target: sgm.models.diffusion.DiffusionEngine + params: + scale_factor: 0.13025 + disable_first_stage_autocast: True + + denoiser_config: + target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser + params: + num_idx: 1000 + + weighting_config: + target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting + scaling_config: + target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization + + network_config: + target: sgm.modules.diffusionmodules.openaimodel.UNetModel + params: + adm_in_channels: 2816 + num_classes: sequential + use_checkpoint: True + in_channels: 9 + out_channels: 4 + model_channels: 320 + attention_resolutions: [4, 2] + num_res_blocks: 2 + channel_mult: [1, 2, 4] + num_head_channels: 64 + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: [1, 2, 10] # note: the first is unused (due to attn_res starting at 2) 32, 16, 8 --> 64, 32, 16 + context_dim: 2048 + spatial_transformer_attn_type: softmax-xformers + legacy: False + + conditioner_config: + target: sgm.modules.GeneralConditioner + params: + emb_models: + # crossattn cond + - is_trainable: False + input_key: txt + target: sgm.modules.encoders.modules.FrozenCLIPEmbedder + params: + layer: hidden + layer_idx: 11 + # crossattn and vector cond + - is_trainable: False + input_key: txt + target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2 + params: + arch: ViT-bigG-14 + version: laion2b_s39b_b160k + freeze: True + layer: penultimate + always_return_pooled: True + legacy: False + # vector cond + - is_trainable: False + input_key: original_size_as_tuple + target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND + params: + outdim: 256 # multiplied by two + # vector cond + - is_trainable: False + input_key: crop_coords_top_left + target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND + params: + outdim: 256 # multiplied by two + # vector cond + - is_trainable: False + input_key: target_size_as_tuple + target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND + params: + outdim: 256 # multiplied by two + + first_stage_config: + target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + attn_type: vanilla-xformers + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [1, 2, 4, 4] + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity diff --git a/extensions-builtin/Lora/lyco_helpers.py b/extensions-builtin/Lora/lyco_helpers.py index 279b34bc9..1679a0ce6 100644 --- a/extensions-builtin/Lora/lyco_helpers.py +++ b/extensions-builtin/Lora/lyco_helpers.py @@ -19,3 +19,50 @@ def rebuild_cp_decomposition(up, down, mid): up = up.reshape(up.size(0), -1) down = down.reshape(down.size(0), -1) return torch.einsum('n m k l, i n, m j -> i j k l', mid, up, down) + + +# copied from https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/lokr.py +def factorization(dimension: int, factor:int=-1) -> tuple[int, int]: + ''' + return a tuple of two value of input dimension decomposed by the number closest to factor + second value is higher or equal than first value. + + In LoRA with Kroneckor Product, first value is a value for weight scale. + secon value is a value for weight. + + Becuase of non-commutative property, A⊗B ≠ B⊗A. Meaning of two matrices is slightly different. + + examples) + factor + -1 2 4 8 16 ... + 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 + 128 -> 8, 16 128 -> 2, 64 128 -> 4, 32 128 -> 8, 16 128 -> 8, 16 + 250 -> 10, 25 250 -> 2, 125 250 -> 2, 125 250 -> 5, 50 250 -> 10, 25 + 360 -> 8, 45 360 -> 2, 180 360 -> 4, 90 360 -> 8, 45 360 -> 12, 30 + 512 -> 16, 32 512 -> 2, 256 512 -> 4, 128 512 -> 8, 64 512 -> 16, 32 + 1024 -> 32, 32 1024 -> 2, 512 1024 -> 4, 256 1024 -> 8, 128 1024 -> 16, 64 + ''' + + if factor > 0 and (dimension % factor) == 0: + m = factor + n = dimension // factor + if m > n: + n, m = m, n + return m, n + if factor < 0: + factor = dimension + m, n = 1, dimension + length = m + n + while m length or new_m>factor: + break + else: + m, n = new_m, new_n + if m > n: + n, m = m, n + return m, n + diff --git a/extensions-builtin/Lora/network.py b/extensions-builtin/Lora/network.py index 6021fd8de..a62e5eff9 100644 --- a/extensions-builtin/Lora/network.py +++ b/extensions-builtin/Lora/network.py @@ -137,7 +137,7 @@ class NetworkModule: def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None): if self.bias is not None: updown = updown.reshape(self.bias.shape) - updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype) + updown += self.bias.to(orig_weight.device, dtype=updown.dtype) updown = updown.reshape(output_shape) if len(output_shape) == 4: diff --git a/extensions-builtin/Lora/network_full.py b/extensions-builtin/Lora/network_full.py index bf6930e96..f221c95f3 100644 --- a/extensions-builtin/Lora/network_full.py +++ b/extensions-builtin/Lora/network_full.py @@ -18,9 +18,9 @@ class NetworkModuleFull(network.NetworkModule): def calc_updown(self, orig_weight): output_shape = self.weight.shape - updown = self.weight.to(orig_weight.device, dtype=orig_weight.dtype) + updown = self.weight.to(orig_weight.device) if self.ex_bias is not None: - ex_bias = self.ex_bias.to(orig_weight.device, dtype=orig_weight.dtype) + ex_bias = self.ex_bias.to(orig_weight.device) else: ex_bias = None diff --git a/extensions-builtin/Lora/network_glora.py b/extensions-builtin/Lora/network_glora.py index 492d48707..efe5c6814 100644 --- a/extensions-builtin/Lora/network_glora.py +++ b/extensions-builtin/Lora/network_glora.py @@ -22,12 +22,12 @@ class NetworkModuleGLora(network.NetworkModule): self.w2b = weights.w["b2.weight"] def calc_updown(self, orig_weight): - w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype) - w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype) - w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype) - w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype) + w1a = self.w1a.to(orig_weight.device) + w1b = self.w1b.to(orig_weight.device) + w2a = self.w2a.to(orig_weight.device) + w2b = self.w2b.to(orig_weight.device) output_shape = [w1a.size(0), w1b.size(1)] - updown = ((w2b @ w1b) + ((orig_weight @ w2a) @ w1a)) + updown = ((w2b @ w1b) + ((orig_weight.to(dtype = w1a.dtype) @ w2a) @ w1a)) return self.finalize_updown(updown, orig_weight, output_shape) diff --git a/extensions-builtin/Lora/network_hada.py b/extensions-builtin/Lora/network_hada.py index 5fcb0695f..d95a0fd18 100644 --- a/extensions-builtin/Lora/network_hada.py +++ b/extensions-builtin/Lora/network_hada.py @@ -27,16 +27,16 @@ class NetworkModuleHada(network.NetworkModule): self.t2 = weights.w.get("hada_t2") def calc_updown(self, orig_weight): - w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype) - w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype) - w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype) - w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype) + w1a = self.w1a.to(orig_weight.device) + w1b = self.w1b.to(orig_weight.device) + w2a = self.w2a.to(orig_weight.device) + w2b = self.w2b.to(orig_weight.device) output_shape = [w1a.size(0), w1b.size(1)] if self.t1 is not None: output_shape = [w1a.size(1), w1b.size(1)] - t1 = self.t1.to(orig_weight.device, dtype=orig_weight.dtype) + t1 = self.t1.to(orig_weight.device) updown1 = lyco_helpers.make_weight_cp(t1, w1a, w1b) output_shape += t1.shape[2:] else: @@ -45,7 +45,7 @@ class NetworkModuleHada(network.NetworkModule): updown1 = lyco_helpers.rebuild_conventional(w1a, w1b, output_shape) if self.t2 is not None: - t2 = self.t2.to(orig_weight.device, dtype=orig_weight.dtype) + t2 = self.t2.to(orig_weight.device) updown2 = lyco_helpers.make_weight_cp(t2, w2a, w2b) else: updown2 = lyco_helpers.rebuild_conventional(w2a, w2b, output_shape) diff --git a/extensions-builtin/Lora/network_ia3.py b/extensions-builtin/Lora/network_ia3.py index 7edc42497..96faeaf3e 100644 --- a/extensions-builtin/Lora/network_ia3.py +++ b/extensions-builtin/Lora/network_ia3.py @@ -17,7 +17,7 @@ class NetworkModuleIa3(network.NetworkModule): self.on_input = weights.w["on_input"].item() def calc_updown(self, orig_weight): - w = self.w.to(orig_weight.device, dtype=orig_weight.dtype) + w = self.w.to(orig_weight.device) output_shape = [w.size(0), orig_weight.size(1)] if self.on_input: diff --git a/extensions-builtin/Lora/network_lokr.py b/extensions-builtin/Lora/network_lokr.py index 340acdab3..fcdaeafd8 100644 --- a/extensions-builtin/Lora/network_lokr.py +++ b/extensions-builtin/Lora/network_lokr.py @@ -37,22 +37,22 @@ class NetworkModuleLokr(network.NetworkModule): def calc_updown(self, orig_weight): if self.w1 is not None: - w1 = self.w1.to(orig_weight.device, dtype=orig_weight.dtype) + w1 = self.w1.to(orig_weight.device) else: - w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype) - w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype) + w1a = self.w1a.to(orig_weight.device) + w1b = self.w1b.to(orig_weight.device) w1 = w1a @ w1b if self.w2 is not None: - w2 = self.w2.to(orig_weight.device, dtype=orig_weight.dtype) + w2 = self.w2.to(orig_weight.device) elif self.t2 is None: - w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype) - w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype) + w2a = self.w2a.to(orig_weight.device) + w2b = self.w2b.to(orig_weight.device) w2 = w2a @ w2b else: - t2 = self.t2.to(orig_weight.device, dtype=orig_weight.dtype) - w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype) - w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype) + t2 = self.t2.to(orig_weight.device) + w2a = self.w2a.to(orig_weight.device) + w2b = self.w2b.to(orig_weight.device) w2 = lyco_helpers.make_weight_cp(t2, w2a, w2b) output_shape = [w1.size(0) * w2.size(0), w1.size(1) * w2.size(1)] diff --git a/extensions-builtin/Lora/network_lora.py b/extensions-builtin/Lora/network_lora.py index 26c0a72c2..4cc402951 100644 --- a/extensions-builtin/Lora/network_lora.py +++ b/extensions-builtin/Lora/network_lora.py @@ -61,13 +61,13 @@ class NetworkModuleLora(network.NetworkModule): return module def calc_updown(self, orig_weight): - up = self.up_model.weight.to(orig_weight.device, dtype=orig_weight.dtype) - down = self.down_model.weight.to(orig_weight.device, dtype=orig_weight.dtype) + up = self.up_model.weight.to(orig_weight.device) + down = self.down_model.weight.to(orig_weight.device) output_shape = [up.size(0), down.size(1)] if self.mid_model is not None: # cp-decomposition - mid = self.mid_model.weight.to(orig_weight.device, dtype=orig_weight.dtype) + mid = self.mid_model.weight.to(orig_weight.device) updown = lyco_helpers.rebuild_cp_decomposition(up, down, mid) output_shape += mid.shape[2:] else: diff --git a/extensions-builtin/Lora/network_norm.py b/extensions-builtin/Lora/network_norm.py index ce4501580..d25afcbb9 100644 --- a/extensions-builtin/Lora/network_norm.py +++ b/extensions-builtin/Lora/network_norm.py @@ -18,10 +18,10 @@ class NetworkModuleNorm(network.NetworkModule): def calc_updown(self, orig_weight): output_shape = self.w_norm.shape - updown = self.w_norm.to(orig_weight.device, dtype=orig_weight.dtype) + updown = self.w_norm.to(orig_weight.device) if self.b_norm is not None: - ex_bias = self.b_norm.to(orig_weight.device, dtype=orig_weight.dtype) + ex_bias = self.b_norm.to(orig_weight.device) else: ex_bias = None diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py new file mode 100644 index 000000000..fa647020f --- /dev/null +++ b/extensions-builtin/Lora/network_oft.py @@ -0,0 +1,82 @@ +import torch +import network +from lyco_helpers import factorization +from einops import rearrange + + +class ModuleTypeOFT(network.ModuleType): + def create_module(self, net: network.Network, weights: network.NetworkWeights): + if all(x in weights.w for x in ["oft_blocks"]) or all(x in weights.w for x in ["oft_diag"]): + return NetworkModuleOFT(net, weights) + + return None + +# Supports both kohya-ss' implementation of COFT https://github.com/kohya-ss/sd-scripts/blob/main/networks/oft.py +# and KohakuBlueleaf's implementation of OFT/COFT https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/diag_oft.py +class NetworkModuleOFT(network.NetworkModule): + def __init__(self, net: network.Network, weights: network.NetworkWeights): + + super().__init__(net, weights) + + self.lin_module = None + self.org_module: list[torch.Module] = [self.sd_module] + + self.scale = 1.0 + + # kohya-ss + if "oft_blocks" in weights.w.keys(): + self.is_kohya = True + self.oft_blocks = weights.w["oft_blocks"] # (num_blocks, block_size, block_size) + self.alpha = weights.w["alpha"] # alpha is constraint + self.dim = self.oft_blocks.shape[0] # lora dim + # LyCORIS + elif "oft_diag" in weights.w.keys(): + self.is_kohya = False + self.oft_blocks = weights.w["oft_diag"] + # self.alpha is unused + self.dim = self.oft_blocks.shape[1] # (num_blocks, block_size, block_size) + + is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear] + is_conv = type(self.sd_module) in [torch.nn.Conv2d] + is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention] # unsupported + + if is_linear: + self.out_dim = self.sd_module.out_features + elif is_conv: + self.out_dim = self.sd_module.out_channels + elif is_other_linear: + self.out_dim = self.sd_module.embed_dim + + if self.is_kohya: + self.constraint = self.alpha * self.out_dim + self.num_blocks = self.dim + self.block_size = self.out_dim // self.dim + else: + self.constraint = None + self.block_size, self.num_blocks = factorization(self.out_dim, self.dim) + + def calc_updown(self, orig_weight): + oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) + eye = torch.eye(self.block_size, device=self.oft_blocks.device) + + if self.is_kohya: + block_Q = oft_blocks - oft_blocks.transpose(1, 2) # ensure skew-symmetric orthogonal matrix + norm_Q = torch.norm(block_Q.flatten()) + new_norm_Q = torch.clamp(norm_Q, max=self.constraint) + block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) + oft_blocks = torch.matmul(eye + block_Q, (eye - block_Q).float().inverse()) + + R = oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) + + # This errors out for MultiheadAttention, might need to be handled up-stream + merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size) + merged_weight = torch.einsum( + 'k n m, k n ... -> k m ...', + R, + merged_weight + ) + merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...') + + updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight + output_shape = orig_weight.shape + return self.finalize_updown(updown, orig_weight, output_shape) diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 60d8dec4c..72ebd6241 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -1,3 +1,4 @@ +import gradio as gr import logging import os import re @@ -11,6 +12,7 @@ import network_ia3 import network_lokr import network_full import network_norm +import network_oft import torch from typing import Union @@ -28,6 +30,7 @@ module_types = [ network_full.ModuleTypeFull(), network_norm.ModuleTypeNorm(), network_glora.ModuleTypeGLora(), + network_oft.ModuleTypeOFT(), ] @@ -157,7 +160,8 @@ def load_network(name, network_on_disk): bundle_embeddings = {} for key_network, weight in sd.items(): - key_network_without_network_parts, network_part = key_network.split(".", 1) + key_network_without_network_parts, _, network_part = key_network.partition(".") + if key_network_without_network_parts == "bundle_emb": emb_name, vec_name = network_part.split(".", 1) emb_dict = bundle_embeddings.get(emb_name, {}) @@ -189,6 +193,17 @@ def load_network(name, network_on_disk): key = key_network_without_network_parts.replace("lora_te1_text_model", "transformer_text_model") sd_module = shared.sd_model.network_layer_mapping.get(key, None) + # kohya_ss OFT module + elif sd_module is None and "oft_unet" in key_network_without_network_parts: + key = key_network_without_network_parts.replace("oft_unet", "diffusion_model") + sd_module = shared.sd_model.network_layer_mapping.get(key, None) + + # KohakuBlueLeaf OFT module + if sd_module is None and "oft_diag" in key: + key = key_network_without_network_parts.replace("lora_unet", "diffusion_model") + key = key_network_without_network_parts.replace("lora_te1_text_model", "0_transformer_text_model") + sd_module = shared.sd_model.network_layer_mapping.get(key, None) + if sd_module is None: keys_failed_to_match[key_network] = key continue @@ -300,7 +315,12 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No emb_db.skipped_embeddings[name] = embedding if failed_to_load_networks: - sd_hijack.model_hijack.comments.append("Networks not found: " + ", ".join(failed_to_load_networks)) + lora_not_found_message = f'Lora not found: {", ".join(failed_to_load_networks)}' + sd_hijack.model_hijack.comments.append(lora_not_found_message) + if shared.opts.lora_not_found_warning_console: + print(f'\n{lora_not_found_message}\n') + if shared.opts.lora_not_found_gradio_warning: + gr.Warning(lora_not_found_message) purge_networks_from_memory() @@ -375,18 +395,26 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn if module is not None and hasattr(self, 'weight'): try: with torch.no_grad(): - updown, ex_bias = module.calc_updown(self.weight) + if getattr(self, 'fp16_weight', None) is None: + weight = self.weight + bias = self.bias + else: + weight = self.fp16_weight.clone().to(self.weight.device) + bias = getattr(self, 'fp16_bias', None) + if bias is not None: + bias = bias.clone().to(self.bias.device) + updown, ex_bias = module.calc_updown(weight) - if len(self.weight.shape) == 4 and self.weight.shape[1] == 9: + if len(weight.shape) == 4 and weight.shape[1] == 9: # inpainting model. zero pad updown to make channel[1] 4 to 9 updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5)) - self.weight += updown + self.weight.copy_((weight.to(dtype=updown.dtype) + updown).to(dtype=self.weight.dtype)) if ex_bias is not None and hasattr(self, 'bias'): if self.bias is None: - self.bias = torch.nn.Parameter(ex_bias) + self.bias = torch.nn.Parameter(ex_bias).to(self.weight.dtype) else: - self.bias += ex_bias + self.bias.copy_((bias + ex_bias).to(dtype=self.bias.dtype)) except RuntimeError as e: logging.debug(f"Network {net.name} layer {network_layer_name}: {e}") extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1 diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py index ef23968c5..1518f7e5c 100644 --- a/extensions-builtin/Lora/scripts/lora_script.py +++ b/extensions-builtin/Lora/scripts/lora_script.py @@ -39,6 +39,8 @@ shared.options_templates.update(shared.options_section(('extra_networks', "Extra "lora_show_all": shared.OptionInfo(False, "Always show all networks on the Lora page").info("otherwise, those detected as for incompatible version of Stable Diffusion will be hidden"), "lora_hide_unknown_for_versions": shared.OptionInfo([], "Hide networks of unknown versions for model versions", gr.CheckboxGroup, {"choices": ["SD1", "SD2", "SDXL"]}), "lora_in_memory_limit": shared.OptionInfo(0, "Number of Lora networks to keep cached in memory", gr.Number, {"precision": 0}), + "lora_not_found_warning_console": shared.OptionInfo(False, "Lora not found warning in console"), + "lora_not_found_gradio_warning": shared.OptionInfo(False, "Lora not found warning popup in webui"), })) diff --git a/extensions-builtin/Lora/ui_edit_user_metadata.py b/extensions-builtin/Lora/ui_edit_user_metadata.py index c70119090..3160aecfa 100644 --- a/extensions-builtin/Lora/ui_edit_user_metadata.py +++ b/extensions-builtin/Lora/ui_edit_user_metadata.py @@ -54,12 +54,13 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor) self.slider_preferred_weight = None self.edit_notes = None - def save_lora_user_metadata(self, name, desc, sd_version, activation_text, preferred_weight, notes): + def save_lora_user_metadata(self, name, desc, sd_version, activation_text, preferred_weight, negative_text, notes): user_metadata = self.get_user_metadata(name) user_metadata["description"] = desc user_metadata["sd version"] = sd_version user_metadata["activation text"] = activation_text user_metadata["preferred weight"] = preferred_weight + user_metadata["negative text"] = negative_text user_metadata["notes"] = notes self.write_user_metadata(name, user_metadata) @@ -127,6 +128,7 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor) gr.HighlightedText.update(value=gradio_tags, visible=True if tags else False), user_metadata.get('activation text', ''), float(user_metadata.get('preferred weight', 0.0)), + user_metadata.get('negative text', ''), gr.update(visible=True if tags else False), gr.update(value=self.generate_random_prompt_from_tags(tags), visible=True if tags else False), ] @@ -162,7 +164,7 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor) self.taginfo = gr.HighlightedText(label="Training dataset tags") self.edit_activation_text = gr.Text(label='Activation text', info="Will be added to prompt along with Lora") self.slider_preferred_weight = gr.Slider(label='Preferred weight', info="Set to 0 to disable", minimum=0.0, maximum=2.0, step=0.01) - + self.edit_negative_text = gr.Text(label='Negative prompt', info="Will be added to negative prompts") with gr.Row() as row_random_prompt: with gr.Column(scale=8): random_prompt = gr.Textbox(label='Random prompt', lines=4, max_lines=4, interactive=False) @@ -198,6 +200,7 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor) self.taginfo, self.edit_activation_text, self.slider_preferred_weight, + self.edit_negative_text, row_random_prompt, random_prompt, ] @@ -211,7 +214,9 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor) self.select_sd_version, self.edit_activation_text, self.slider_preferred_weight, + self.edit_negative_text, self.edit_notes, ] + self.setup_save_handler(self.button_save, self.save_lora_user_metadata, edited_components) diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py index 55409a782..e714fac46 100644 --- a/extensions-builtin/Lora/ui_extra_networks_lora.py +++ b/extensions-builtin/Lora/ui_extra_networks_lora.py @@ -17,6 +17,8 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): def create_item(self, name, index=None, enable_filter=True): lora_on_disk = networks.available_networks.get(name) + if lora_on_disk is None: + return path, ext = os.path.splitext(lora_on_disk.filename) @@ -43,6 +45,11 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): if activation_text: item["prompt"] += " + " + quote_js(" " + activation_text) + negative_prompt = item["user_metadata"].get("negative text") + item["negative_prompt"] = quote_js("") + if negative_prompt: + item["negative_prompt"] = quote_js('(' + negative_prompt + ':1)') + sd_version = item["user_metadata"].get("sd version") if sd_version in network.SdVersion.__members__: item["sd_version"] = sd_version @@ -66,9 +73,10 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): return item def list_items(self): - for index, name in enumerate(networks.available_networks): + # instantiate a list to protect against concurrent modification + names = list(networks.available_networks) + for index, name in enumerate(names): item = self.create_item(name, index) - if item is not None: yield item diff --git a/extensions-builtin/ScuNET/scripts/scunet_model.py b/extensions-builtin/ScuNET/scripts/scunet_model.py index 167d2f64b..f799cb76d 100644 --- a/extensions-builtin/ScuNET/scripts/scunet_model.py +++ b/extensions-builtin/ScuNET/scripts/scunet_model.py @@ -3,14 +3,11 @@ import sys import PIL.Image import numpy as np import torch -from tqdm import tqdm import modules.upscaler from modules import devices, modelloader, script_callbacks, errors -from scunet_model_arch import SCUNet - -from modules.modelloader import load_file_from_url from modules.shared import opts +from modules.upscaler_utils import tiled_upscale_2 class UpscalerScuNET(modules.upscaler.Upscaler): @@ -42,47 +39,6 @@ class UpscalerScuNET(modules.upscaler.Upscaler): scalers.append(scaler_data2) self.scalers = scalers - @staticmethod - @torch.no_grad() - def tiled_inference(img, model): - # test the image tile by tile - h, w = img.shape[2:] - tile = opts.SCUNET_tile - tile_overlap = opts.SCUNET_tile_overlap - if tile == 0: - return model(img) - - device = devices.get_device_for('scunet') - assert tile % 8 == 0, "tile size should be a multiple of window_size" - sf = 1 - - stride = tile - tile_overlap - h_idx_list = list(range(0, h - tile, stride)) + [h - tile] - w_idx_list = list(range(0, w - tile, stride)) + [w - tile] - E = torch.zeros(1, 3, h * sf, w * sf, dtype=img.dtype, device=device) - W = torch.zeros_like(E, dtype=devices.dtype, device=device) - - with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="ScuNET tiles") as pbar: - for h_idx in h_idx_list: - - for w_idx in w_idx_list: - - in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile] - - out_patch = model(in_patch) - out_patch_mask = torch.ones_like(out_patch) - - E[ - ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf - ].add_(out_patch) - W[ - ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf - ].add_(out_patch_mask) - pbar.update(1) - output = E.div_(W) - - return output - def do_upscale(self, img: PIL.Image.Image, selected_file): devices.torch_gc() @@ -106,7 +62,16 @@ class UpscalerScuNET(modules.upscaler.Upscaler): _img[:, :, :h, :w] = torch_img # pad image torch_img = _img - torch_output = self.tiled_inference(torch_img, model).squeeze(0) + with torch.no_grad(): + torch_output = tiled_upscale_2( + torch_img, + model, + tile_size=opts.SCUNET_tile, + tile_overlap=opts.SCUNET_tile_overlap, + scale=1, + device=devices.get_device_for('scunet'), + desc="ScuNET tiles", + ).squeeze(0) torch_output = torch_output[:, :h * 1, :w * 1] # remove padding, if any np_output: np.ndarray = torch_output.float().cpu().clamp_(0, 1).numpy() del torch_img, torch_output @@ -120,17 +85,10 @@ class UpscalerScuNET(modules.upscaler.Upscaler): device = devices.get_device_for('scunet') if path.startswith("http"): # TODO: this doesn't use `path` at all? - filename = load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth") + filename = modelloader.load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth") else: filename = path - model = SCUNet(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64) - model.load_state_dict(torch.load(filename), strict=True) - model.eval() - for _, v in model.named_parameters(): - v.requires_grad = False - model = model.to(device) - - return model + return modelloader.load_spandrel_model(filename, device=device, expected_architecture='SCUNet') def on_ui_settings(): diff --git a/extensions-builtin/ScuNET/scunet_model_arch.py b/extensions-builtin/ScuNET/scunet_model_arch.py deleted file mode 100644 index b51a88062..000000000 --- a/extensions-builtin/ScuNET/scunet_model_arch.py +++ /dev/null @@ -1,268 +0,0 @@ -# -*- coding: utf-8 -*- -import numpy as np -import torch -import torch.nn as nn -from einops import rearrange -from einops.layers.torch import Rearrange -from timm.models.layers import trunc_normal_, DropPath - - -class WMSA(nn.Module): - """ Self-attention module in Swin Transformer - """ - - def __init__(self, input_dim, output_dim, head_dim, window_size, type): - super(WMSA, self).__init__() - self.input_dim = input_dim - self.output_dim = output_dim - self.head_dim = head_dim - self.scale = self.head_dim ** -0.5 - self.n_heads = input_dim // head_dim - self.window_size = window_size - self.type = type - self.embedding_layer = nn.Linear(self.input_dim, 3 * self.input_dim, bias=True) - - self.relative_position_params = nn.Parameter( - torch.zeros((2 * window_size - 1) * (2 * window_size - 1), self.n_heads)) - - self.linear = nn.Linear(self.input_dim, self.output_dim) - - trunc_normal_(self.relative_position_params, std=.02) - self.relative_position_params = torch.nn.Parameter( - self.relative_position_params.view(2 * window_size - 1, 2 * window_size - 1, self.n_heads).transpose(1, - 2).transpose( - 0, 1)) - - def generate_mask(self, h, w, p, shift): - """ generating the mask of SW-MSA - Args: - shift: shift parameters in CyclicShift. - Returns: - attn_mask: should be (1 1 w p p), - """ - # supporting square. - attn_mask = torch.zeros(h, w, p, p, p, p, dtype=torch.bool, device=self.relative_position_params.device) - if self.type == 'W': - return attn_mask - - s = p - shift - attn_mask[-1, :, :s, :, s:, :] = True - attn_mask[-1, :, s:, :, :s, :] = True - attn_mask[:, -1, :, :s, :, s:] = True - attn_mask[:, -1, :, s:, :, :s] = True - attn_mask = rearrange(attn_mask, 'w1 w2 p1 p2 p3 p4 -> 1 1 (w1 w2) (p1 p2) (p3 p4)') - return attn_mask - - def forward(self, x): - """ Forward pass of Window Multi-head Self-attention module. - Args: - x: input tensor with shape of [b h w c]; - attn_mask: attention mask, fill -inf where the value is True; - Returns: - output: tensor shape [b h w c] - """ - if self.type != 'W': - x = torch.roll(x, shifts=(-(self.window_size // 2), -(self.window_size // 2)), dims=(1, 2)) - - x = rearrange(x, 'b (w1 p1) (w2 p2) c -> b w1 w2 p1 p2 c', p1=self.window_size, p2=self.window_size) - h_windows = x.size(1) - w_windows = x.size(2) - # square validation - # assert h_windows == w_windows - - x = rearrange(x, 'b w1 w2 p1 p2 c -> b (w1 w2) (p1 p2) c', p1=self.window_size, p2=self.window_size) - qkv = self.embedding_layer(x) - q, k, v = rearrange(qkv, 'b nw np (threeh c) -> threeh b nw np c', c=self.head_dim).chunk(3, dim=0) - sim = torch.einsum('hbwpc,hbwqc->hbwpq', q, k) * self.scale - # Adding learnable relative embedding - sim = sim + rearrange(self.relative_embedding(), 'h p q -> h 1 1 p q') - # Using Attn Mask to distinguish different subwindows. - if self.type != 'W': - attn_mask = self.generate_mask(h_windows, w_windows, self.window_size, shift=self.window_size // 2) - sim = sim.masked_fill_(attn_mask, float("-inf")) - - probs = nn.functional.softmax(sim, dim=-1) - output = torch.einsum('hbwij,hbwjc->hbwic', probs, v) - output = rearrange(output, 'h b w p c -> b w p (h c)') - output = self.linear(output) - output = rearrange(output, 'b (w1 w2) (p1 p2) c -> b (w1 p1) (w2 p2) c', w1=h_windows, p1=self.window_size) - - if self.type != 'W': - output = torch.roll(output, shifts=(self.window_size // 2, self.window_size // 2), dims=(1, 2)) - - return output - - def relative_embedding(self): - cord = torch.tensor(np.array([[i, j] for i in range(self.window_size) for j in range(self.window_size)])) - relation = cord[:, None, :] - cord[None, :, :] + self.window_size - 1 - # negative is allowed - return self.relative_position_params[:, relation[:, :, 0].long(), relation[:, :, 1].long()] - - -class Block(nn.Module): - def __init__(self, input_dim, output_dim, head_dim, window_size, drop_path, type='W', input_resolution=None): - """ SwinTransformer Block - """ - super(Block, self).__init__() - self.input_dim = input_dim - self.output_dim = output_dim - assert type in ['W', 'SW'] - self.type = type - if input_resolution <= window_size: - self.type = 'W' - - self.ln1 = nn.LayerNorm(input_dim) - self.msa = WMSA(input_dim, input_dim, head_dim, window_size, self.type) - self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() - self.ln2 = nn.LayerNorm(input_dim) - self.mlp = nn.Sequential( - nn.Linear(input_dim, 4 * input_dim), - nn.GELU(), - nn.Linear(4 * input_dim, output_dim), - ) - - def forward(self, x): - x = x + self.drop_path(self.msa(self.ln1(x))) - x = x + self.drop_path(self.mlp(self.ln2(x))) - return x - - -class ConvTransBlock(nn.Module): - def __init__(self, conv_dim, trans_dim, head_dim, window_size, drop_path, type='W', input_resolution=None): - """ SwinTransformer and Conv Block - """ - super(ConvTransBlock, self).__init__() - self.conv_dim = conv_dim - self.trans_dim = trans_dim - self.head_dim = head_dim - self.window_size = window_size - self.drop_path = drop_path - self.type = type - self.input_resolution = input_resolution - - assert self.type in ['W', 'SW'] - if self.input_resolution <= self.window_size: - self.type = 'W' - - self.trans_block = Block(self.trans_dim, self.trans_dim, self.head_dim, self.window_size, self.drop_path, - self.type, self.input_resolution) - self.conv1_1 = nn.Conv2d(self.conv_dim + self.trans_dim, self.conv_dim + self.trans_dim, 1, 1, 0, bias=True) - self.conv1_2 = nn.Conv2d(self.conv_dim + self.trans_dim, self.conv_dim + self.trans_dim, 1, 1, 0, bias=True) - - self.conv_block = nn.Sequential( - nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False), - nn.ReLU(True), - nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False) - ) - - def forward(self, x): - conv_x, trans_x = torch.split(self.conv1_1(x), (self.conv_dim, self.trans_dim), dim=1) - conv_x = self.conv_block(conv_x) + conv_x - trans_x = Rearrange('b c h w -> b h w c')(trans_x) - trans_x = self.trans_block(trans_x) - trans_x = Rearrange('b h w c -> b c h w')(trans_x) - res = self.conv1_2(torch.cat((conv_x, trans_x), dim=1)) - x = x + res - - return x - - -class SCUNet(nn.Module): - # def __init__(self, in_nc=3, config=[2, 2, 2, 2, 2, 2, 2], dim=64, drop_path_rate=0.0, input_resolution=256): - def __init__(self, in_nc=3, config=None, dim=64, drop_path_rate=0.0, input_resolution=256): - super(SCUNet, self).__init__() - if config is None: - config = [2, 2, 2, 2, 2, 2, 2] - self.config = config - self.dim = dim - self.head_dim = 32 - self.window_size = 8 - - # drop path rate for each layer - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(config))] - - self.m_head = [nn.Conv2d(in_nc, dim, 3, 1, 1, bias=False)] - - begin = 0 - self.m_down1 = [ConvTransBlock(dim // 2, dim // 2, self.head_dim, self.window_size, dpr[i + begin], - 'W' if not i % 2 else 'SW', input_resolution) - for i in range(config[0])] + \ - [nn.Conv2d(dim, 2 * dim, 2, 2, 0, bias=False)] - - begin += config[0] - self.m_down2 = [ConvTransBlock(dim, dim, self.head_dim, self.window_size, dpr[i + begin], - 'W' if not i % 2 else 'SW', input_resolution // 2) - for i in range(config[1])] + \ - [nn.Conv2d(2 * dim, 4 * dim, 2, 2, 0, bias=False)] - - begin += config[1] - self.m_down3 = [ConvTransBlock(2 * dim, 2 * dim, self.head_dim, self.window_size, dpr[i + begin], - 'W' if not i % 2 else 'SW', input_resolution // 4) - for i in range(config[2])] + \ - [nn.Conv2d(4 * dim, 8 * dim, 2, 2, 0, bias=False)] - - begin += config[2] - self.m_body = [ConvTransBlock(4 * dim, 4 * dim, self.head_dim, self.window_size, dpr[i + begin], - 'W' if not i % 2 else 'SW', input_resolution // 8) - for i in range(config[3])] - - begin += config[3] - self.m_up3 = [nn.ConvTranspose2d(8 * dim, 4 * dim, 2, 2, 0, bias=False), ] + \ - [ConvTransBlock(2 * dim, 2 * dim, self.head_dim, self.window_size, dpr[i + begin], - 'W' if not i % 2 else 'SW', input_resolution // 4) - for i in range(config[4])] - - begin += config[4] - self.m_up2 = [nn.ConvTranspose2d(4 * dim, 2 * dim, 2, 2, 0, bias=False), ] + \ - [ConvTransBlock(dim, dim, self.head_dim, self.window_size, dpr[i + begin], - 'W' if not i % 2 else 'SW', input_resolution // 2) - for i in range(config[5])] - - begin += config[5] - self.m_up1 = [nn.ConvTranspose2d(2 * dim, dim, 2, 2, 0, bias=False), ] + \ - [ConvTransBlock(dim // 2, dim // 2, self.head_dim, self.window_size, dpr[i + begin], - 'W' if not i % 2 else 'SW', input_resolution) - for i in range(config[6])] - - self.m_tail = [nn.Conv2d(dim, in_nc, 3, 1, 1, bias=False)] - - self.m_head = nn.Sequential(*self.m_head) - self.m_down1 = nn.Sequential(*self.m_down1) - self.m_down2 = nn.Sequential(*self.m_down2) - self.m_down3 = nn.Sequential(*self.m_down3) - self.m_body = nn.Sequential(*self.m_body) - self.m_up3 = nn.Sequential(*self.m_up3) - self.m_up2 = nn.Sequential(*self.m_up2) - self.m_up1 = nn.Sequential(*self.m_up1) - self.m_tail = nn.Sequential(*self.m_tail) - # self.apply(self._init_weights) - - def forward(self, x0): - - h, w = x0.size()[-2:] - paddingBottom = int(np.ceil(h / 64) * 64 - h) - paddingRight = int(np.ceil(w / 64) * 64 - w) - x0 = nn.ReplicationPad2d((0, paddingRight, 0, paddingBottom))(x0) - - x1 = self.m_head(x0) - x2 = self.m_down1(x1) - x3 = self.m_down2(x2) - x4 = self.m_down3(x3) - x = self.m_body(x4) - x = self.m_up3(x + x4) - x = self.m_up2(x + x3) - x = self.m_up1(x + x2) - x = self.m_tail(x + x1) - - x = x[..., :h, :w] - - return x - - def _init_weights(self, m): - if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) - if m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py index ae0d0e6a8..8a555c794 100644 --- a/extensions-builtin/SwinIR/scripts/swinir_model.py +++ b/extensions-builtin/SwinIR/scripts/swinir_model.py @@ -1,20 +1,18 @@ +import logging import sys -import platform import numpy as np import torch from PIL import Image -from tqdm import tqdm from modules import modelloader, devices, script_callbacks, shared -from modules.shared import opts, state -from swinir_model_arch import SwinIR -from swinir_model_arch_v2 import Swin2SR +from modules.shared import opts from modules.upscaler import Upscaler, UpscalerData +from modules.upscaler_utils import tiled_upscale_2 SWINIR_MODEL_URL = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth" -device_swinir = devices.get_device_for('swinir') +logger = logging.getLogger(__name__) class UpscalerSwinIR(Upscaler): @@ -37,26 +35,29 @@ class UpscalerSwinIR(Upscaler): scalers.append(model_data) self.scalers = scalers - def do_upscale(self, img, model_file): - use_compile = hasattr(opts, 'SWIN_torch_compile') and opts.SWIN_torch_compile \ - and int(torch.__version__.split('.')[0]) >= 2 and platform.system() != "Windows" + def do_upscale(self, img: Image.Image, model_file: str) -> Image.Image: current_config = (model_file, opts.SWIN_tile) - if use_compile and self._cached_model_config == current_config: + device = self._get_device() + + if self._cached_model_config == current_config: model = self._cached_model else: - self._cached_model = None try: model = self.load_model(model_file) except Exception as e: print(f"Failed loading SwinIR model {model_file}: {e}", file=sys.stderr) return img - model = model.to(device_swinir, dtype=devices.dtype) - if use_compile: - model = torch.compile(model) - self._cached_model = model - self._cached_model_config = current_config - img = upscale(img, model) + self._cached_model = model + self._cached_model_config = current_config + + img = upscale( + img, + model, + tile=opts.SWIN_tile, + tile_overlap=opts.SWIN_tile_overlap, + device=device, + ) devices.torch_gc() return img @@ -69,69 +70,55 @@ class UpscalerSwinIR(Upscaler): ) else: filename = path - if filename.endswith(".v2.pth"): - model = Swin2SR( - upscale=scale, - in_chans=3, - img_size=64, - window_size=8, - img_range=1.0, - depths=[6, 6, 6, 6, 6, 6], - embed_dim=180, - num_heads=[6, 6, 6, 6, 6, 6], - mlp_ratio=2, - upsampler="nearest+conv", - resi_connection="1conv", - ) - params = None - else: - model = SwinIR( - upscale=scale, - in_chans=3, - img_size=64, - window_size=8, - img_range=1.0, - depths=[6, 6, 6, 6, 6, 6, 6, 6, 6], - embed_dim=240, - num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8], - mlp_ratio=2, - upsampler="nearest+conv", - resi_connection="3conv", - ) - params = "params_ema" - pretrained_model = torch.load(filename) - if params is not None: - model.load_state_dict(pretrained_model[params], strict=True) - else: - model.load_state_dict(pretrained_model, strict=True) - return model + model_descriptor = modelloader.load_spandrel_model( + filename, + device=self._get_device(), + dtype=devices.dtype, + expected_architecture="SwinIR", + ) + if getattr(opts, 'SWIN_torch_compile', False): + try: + model_descriptor.model.compile() + except Exception: + logger.warning("Failed to compile SwinIR model, fallback to JIT", exc_info=True) + return model_descriptor + + def _get_device(self): + return devices.get_device_for('swinir') def upscale( - img, - model, - tile=None, - tile_overlap=None, - window_size=8, - scale=4, + img, + model, + *, + tile: int, + tile_overlap: int, + window_size=8, + scale=4, + device, ): - tile = tile or opts.SWIN_tile - tile_overlap = tile_overlap or opts.SWIN_tile_overlap - img = np.array(img) img = img[:, :, ::-1] img = np.moveaxis(img, 2, 0) / 255 img = torch.from_numpy(img).float() - img = img.unsqueeze(0).to(device_swinir, dtype=devices.dtype) + img = img.unsqueeze(0).to(device, dtype=devices.dtype) with torch.no_grad(), devices.autocast(): _, _, h_old, w_old = img.size() h_pad = (h_old // window_size + 1) * window_size - h_old w_pad = (w_old // window_size + 1) * window_size - w_old img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, : h_old + h_pad, :] img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, : w_old + w_pad] - output = inference(img, model, tile, tile_overlap, window_size, scale) + output = tiled_upscale_2( + img, + model, + tile_size=tile, + tile_overlap=tile_overlap, + scale=scale, + device=device, + desc="SwinIR tiles", + ) output = output[..., : h_old * scale, : w_old * scale] output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy() if output.ndim == 3: @@ -142,51 +129,12 @@ def upscale( return Image.fromarray(output, "RGB") -def inference(img, model, tile, tile_overlap, window_size, scale): - # test the image tile by tile - b, c, h, w = img.size() - tile = min(tile, h, w) - assert tile % window_size == 0, "tile size should be a multiple of window_size" - sf = scale - - stride = tile - tile_overlap - h_idx_list = list(range(0, h - tile, stride)) + [h - tile] - w_idx_list = list(range(0, w - tile, stride)) + [w - tile] - E = torch.zeros(b, c, h * sf, w * sf, dtype=devices.dtype, device=device_swinir).type_as(img) - W = torch.zeros_like(E, dtype=devices.dtype, device=device_swinir) - - with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar: - for h_idx in h_idx_list: - if state.interrupted or state.skipped: - break - - for w_idx in w_idx_list: - if state.interrupted or state.skipped: - break - - in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile] - out_patch = model(in_patch) - out_patch_mask = torch.ones_like(out_patch) - - E[ - ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf - ].add_(out_patch) - W[ - ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf - ].add_(out_patch_mask) - pbar.update(1) - output = E.div_(W) - - return output - - def on_ui_settings(): import gradio as gr shared.opts.add_option("SWIN_tile", shared.OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling"))) shared.opts.add_option("SWIN_tile_overlap", shared.OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}, section=('upscaling', "Upscaling"))) - if int(torch.__version__.split('.')[0]) >= 2 and platform.system() != "Windows": # torch.compile() require pytorch 2.0 or above, and not on Windows - shared.opts.add_option("SWIN_torch_compile", shared.OptionInfo(False, "Use torch.compile to accelerate SwinIR.", gr.Checkbox, {"interactive": True}, section=('upscaling', "Upscaling")).info("Takes longer on first run")) + shared.opts.add_option("SWIN_torch_compile", shared.OptionInfo(False, "Use torch.compile to accelerate SwinIR.", gr.Checkbox, {"interactive": True}, section=('upscaling', "Upscaling")).info("Takes longer on first run")) script_callbacks.on_ui_settings(on_ui_settings) diff --git a/extensions-builtin/SwinIR/swinir_model_arch.py b/extensions-builtin/SwinIR/swinir_model_arch.py deleted file mode 100644 index 93b932747..000000000 --- a/extensions-builtin/SwinIR/swinir_model_arch.py +++ /dev/null @@ -1,867 +0,0 @@ -# ----------------------------------------------------------------------------------- -# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257 -# Originally Written by Ze Liu, Modified by Jingyun Liang. -# ----------------------------------------------------------------------------------- - -import math -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.utils.checkpoint as checkpoint -from timm.models.layers import DropPath, to_2tuple, trunc_normal_ - - -class Mlp(nn.Module): - def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.fc1 = nn.Linear(in_features, hidden_features) - self.act = act_layer() - self.fc2 = nn.Linear(hidden_features, out_features) - self.drop = nn.Dropout(drop) - - def forward(self, x): - x = self.fc1(x) - x = self.act(x) - x = self.drop(x) - x = self.fc2(x) - x = self.drop(x) - return x - - -def window_partition(x, window_size): - """ - Args: - x: (B, H, W, C) - window_size (int): window size - - Returns: - windows: (num_windows*B, window_size, window_size, C) - """ - B, H, W, C = x.shape - x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) - windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) - return windows - - -def window_reverse(windows, window_size, H, W): - """ - Args: - windows: (num_windows*B, window_size, window_size, C) - window_size (int): Window size - H (int): Height of image - W (int): Width of image - - Returns: - x: (B, H, W, C) - """ - B = int(windows.shape[0] / (H * W / window_size / window_size)) - x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) - x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) - return x - - -class WindowAttention(nn.Module): - r""" Window based multi-head self attention (W-MSA) module with relative position bias. - It supports both of shifted and non-shifted window. - - Args: - dim (int): Number of input channels. - window_size (tuple[int]): The height and width of the window. - num_heads (int): Number of attention heads. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set - attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 - proj_drop (float, optional): Dropout ratio of output. Default: 0.0 - """ - - def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): - - super().__init__() - self.dim = dim - self.window_size = window_size # Wh, Ww - self.num_heads = num_heads - head_dim = dim // num_heads - self.scale = qk_scale or head_dim ** -0.5 - - # define a parameter table of relative position bias - self.relative_position_bias_table = nn.Parameter( - torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH - - # get pair-wise relative position index for each token inside the window - coords_h = torch.arange(self.window_size[0]) - coords_w = torch.arange(self.window_size[1]) - coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww - coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww - relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww - relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 - relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 - relative_coords[:, :, 1] += self.window_size[1] - 1 - relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 - relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww - self.register_buffer("relative_position_index", relative_position_index) - - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) - - self.proj_drop = nn.Dropout(proj_drop) - - trunc_normal_(self.relative_position_bias_table, std=.02) - self.softmax = nn.Softmax(dim=-1) - - def forward(self, x, mask=None): - """ - Args: - x: input features with shape of (num_windows*B, N, C) - mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None - """ - B_, N, C = x.shape - qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) - q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) - - q = q * self.scale - attn = (q @ k.transpose(-2, -1)) - - relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( - self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH - relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww - attn = attn + relative_position_bias.unsqueeze(0) - - if mask is not None: - nW = mask.shape[0] - attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) - attn = attn.view(-1, self.num_heads, N, N) - attn = self.softmax(attn) - else: - attn = self.softmax(attn) - - attn = self.attn_drop(attn) - - x = (attn @ v).transpose(1, 2).reshape(B_, N, C) - x = self.proj(x) - x = self.proj_drop(x) - return x - - def extra_repr(self) -> str: - return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' - - def flops(self, N): - # calculate flops for 1 window with token length of N - flops = 0 - # qkv = self.qkv(x) - flops += N * self.dim * 3 * self.dim - # attn = (q @ k.transpose(-2, -1)) - flops += self.num_heads * N * (self.dim // self.num_heads) * N - # x = (attn @ v) - flops += self.num_heads * N * N * (self.dim // self.num_heads) - # x = self.proj(x) - flops += N * self.dim * self.dim - return flops - - -class SwinTransformerBlock(nn.Module): - r""" Swin Transformer Block. - - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resolution. - num_heads (int): Number of attention heads. - window_size (int): Window size. - shift_size (int): Shift size for SW-MSA. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float, optional): Stochastic depth rate. Default: 0.0 - act_layer (nn.Module, optional): Activation layer. Default: nn.GELU - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - """ - - def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., - act_layer=nn.GELU, norm_layer=nn.LayerNorm): - super().__init__() - self.dim = dim - self.input_resolution = input_resolution - self.num_heads = num_heads - self.window_size = window_size - self.shift_size = shift_size - self.mlp_ratio = mlp_ratio - if min(self.input_resolution) <= self.window_size: - # if window size is larger than input resolution, we don't partition windows - self.shift_size = 0 - self.window_size = min(self.input_resolution) - assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" - - self.norm1 = norm_layer(dim) - self.attn = WindowAttention( - dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, - qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) - - self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() - self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) - - if self.shift_size > 0: - attn_mask = self.calculate_mask(self.input_resolution) - else: - attn_mask = None - - self.register_buffer("attn_mask", attn_mask) - - def calculate_mask(self, x_size): - # calculate attention mask for SW-MSA - H, W = x_size - img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 - h_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - w_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - cnt = 0 - for h in h_slices: - for w in w_slices: - img_mask[:, h, w, :] = cnt - cnt += 1 - - mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 - mask_windows = mask_windows.view(-1, self.window_size * self.window_size) - attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) - attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) - - return attn_mask - - def forward(self, x, x_size): - H, W = x_size - B, L, C = x.shape - # assert L == H * W, "input feature has wrong size" - - shortcut = x - x = self.norm1(x) - x = x.view(B, H, W, C) - - # cyclic shift - if self.shift_size > 0: - shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) - else: - shifted_x = x - - # partition windows - x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C - x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C - - # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size - if self.input_resolution == x_size: - attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C - else: - attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) - - # merge windows - attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) - shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C - - # reverse cyclic shift - if self.shift_size > 0: - x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) - else: - x = shifted_x - x = x.view(B, H * W, C) - - # FFN - x = shortcut + self.drop_path(x) - x = x + self.drop_path(self.mlp(self.norm2(x))) - - return x - - def extra_repr(self) -> str: - return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ - f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" - - def flops(self): - flops = 0 - H, W = self.input_resolution - # norm1 - flops += self.dim * H * W - # W-MSA/SW-MSA - nW = H * W / self.window_size / self.window_size - flops += nW * self.attn.flops(self.window_size * self.window_size) - # mlp - flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio - # norm2 - flops += self.dim * H * W - return flops - - -class PatchMerging(nn.Module): - r""" Patch Merging Layer. - - Args: - input_resolution (tuple[int]): Resolution of input feature. - dim (int): Number of input channels. - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - """ - - def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): - super().__init__() - self.input_resolution = input_resolution - self.dim = dim - self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) - self.norm = norm_layer(4 * dim) - - def forward(self, x): - """ - x: B, H*W, C - """ - H, W = self.input_resolution - B, L, C = x.shape - assert L == H * W, "input feature has wrong size" - assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." - - x = x.view(B, H, W, C) - - x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C - x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C - x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C - x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C - x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C - x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C - - x = self.norm(x) - x = self.reduction(x) - - return x - - def extra_repr(self) -> str: - return f"input_resolution={self.input_resolution}, dim={self.dim}" - - def flops(self): - H, W = self.input_resolution - flops = H * W * self.dim - flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim - return flops - - -class BasicLayer(nn.Module): - """ A basic Swin Transformer layer for one stage. - - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resolution. - depth (int): Number of blocks. - num_heads (int): Number of attention heads. - window_size (int): Local window size. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None - use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. - """ - - def __init__(self, dim, input_resolution, depth, num_heads, window_size, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., - drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): - - super().__init__() - self.dim = dim - self.input_resolution = input_resolution - self.depth = depth - self.use_checkpoint = use_checkpoint - - # build blocks - self.blocks = nn.ModuleList([ - SwinTransformerBlock(dim=dim, input_resolution=input_resolution, - num_heads=num_heads, window_size=window_size, - shift_size=0 if (i % 2 == 0) else window_size // 2, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop, attn_drop=attn_drop, - drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, - norm_layer=norm_layer) - for i in range(depth)]) - - # patch merging layer - if downsample is not None: - self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) - else: - self.downsample = None - - def forward(self, x, x_size): - for blk in self.blocks: - if self.use_checkpoint: - x = checkpoint.checkpoint(blk, x, x_size) - else: - x = blk(x, x_size) - if self.downsample is not None: - x = self.downsample(x) - return x - - def extra_repr(self) -> str: - return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" - - def flops(self): - flops = 0 - for blk in self.blocks: - flops += blk.flops() - if self.downsample is not None: - flops += self.downsample.flops() - return flops - - -class RSTB(nn.Module): - """Residual Swin Transformer Block (RSTB). - - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resolution. - depth (int): Number of blocks. - num_heads (int): Number of attention heads. - window_size (int): Local window size. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None - use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. - img_size: Input image size. - patch_size: Patch size. - resi_connection: The convolutional block before residual connection. - """ - - def __init__(self, dim, input_resolution, depth, num_heads, window_size, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., - drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, - img_size=224, patch_size=4, resi_connection='1conv'): - super(RSTB, self).__init__() - - self.dim = dim - self.input_resolution = input_resolution - - self.residual_group = BasicLayer(dim=dim, - input_resolution=input_resolution, - depth=depth, - num_heads=num_heads, - window_size=window_size, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop, attn_drop=attn_drop, - drop_path=drop_path, - norm_layer=norm_layer, - downsample=downsample, - use_checkpoint=use_checkpoint) - - if resi_connection == '1conv': - self.conv = nn.Conv2d(dim, dim, 3, 1, 1) - elif resi_connection == '3conv': - # to save parameters and memory - self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), - nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), - nn.LeakyReLU(negative_slope=0.2, inplace=True), - nn.Conv2d(dim // 4, dim, 3, 1, 1)) - - self.patch_embed = PatchEmbed( - img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, - norm_layer=None) - - self.patch_unembed = PatchUnEmbed( - img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, - norm_layer=None) - - def forward(self, x, x_size): - return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x - - def flops(self): - flops = 0 - flops += self.residual_group.flops() - H, W = self.input_resolution - flops += H * W * self.dim * self.dim * 9 - flops += self.patch_embed.flops() - flops += self.patch_unembed.flops() - - return flops - - -class PatchEmbed(nn.Module): - r""" Image to Patch Embedding - - Args: - img_size (int): Image size. Default: 224. - patch_size (int): Patch token size. Default: 4. - in_chans (int): Number of input image channels. Default: 3. - embed_dim (int): Number of linear projection output channels. Default: 96. - norm_layer (nn.Module, optional): Normalization layer. Default: None - """ - - def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): - super().__init__() - img_size = to_2tuple(img_size) - patch_size = to_2tuple(patch_size) - patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] - self.img_size = img_size - self.patch_size = patch_size - self.patches_resolution = patches_resolution - self.num_patches = patches_resolution[0] * patches_resolution[1] - - self.in_chans = in_chans - self.embed_dim = embed_dim - - if norm_layer is not None: - self.norm = norm_layer(embed_dim) - else: - self.norm = None - - def forward(self, x): - x = x.flatten(2).transpose(1, 2) # B Ph*Pw C - if self.norm is not None: - x = self.norm(x) - return x - - def flops(self): - flops = 0 - H, W = self.img_size - if self.norm is not None: - flops += H * W * self.embed_dim - return flops - - -class PatchUnEmbed(nn.Module): - r""" Image to Patch Unembedding - - Args: - img_size (int): Image size. Default: 224. - patch_size (int): Patch token size. Default: 4. - in_chans (int): Number of input image channels. Default: 3. - embed_dim (int): Number of linear projection output channels. Default: 96. - norm_layer (nn.Module, optional): Normalization layer. Default: None - """ - - def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): - super().__init__() - img_size = to_2tuple(img_size) - patch_size = to_2tuple(patch_size) - patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] - self.img_size = img_size - self.patch_size = patch_size - self.patches_resolution = patches_resolution - self.num_patches = patches_resolution[0] * patches_resolution[1] - - self.in_chans = in_chans - self.embed_dim = embed_dim - - def forward(self, x, x_size): - B, HW, C = x.shape - x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C - return x - - def flops(self): - flops = 0 - return flops - - -class Upsample(nn.Sequential): - """Upsample module. - - Args: - scale (int): Scale factor. Supported scales: 2^n and 3. - num_feat (int): Channel number of intermediate features. - """ - - def __init__(self, scale, num_feat): - m = [] - if (scale & (scale - 1)) == 0: # scale = 2^n - for _ in range(int(math.log(scale, 2))): - m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) - m.append(nn.PixelShuffle(2)) - elif scale == 3: - m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) - m.append(nn.PixelShuffle(3)) - else: - raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') - super(Upsample, self).__init__(*m) - - -class UpsampleOneStep(nn.Sequential): - """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle) - Used in lightweight SR to save parameters. - - Args: - scale (int): Scale factor. Supported scales: 2^n and 3. - num_feat (int): Channel number of intermediate features. - - """ - - def __init__(self, scale, num_feat, num_out_ch, input_resolution=None): - self.num_feat = num_feat - self.input_resolution = input_resolution - m = [] - m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1)) - m.append(nn.PixelShuffle(scale)) - super(UpsampleOneStep, self).__init__(*m) - - def flops(self): - H, W = self.input_resolution - flops = H * W * self.num_feat * 3 * 9 - return flops - - -class SwinIR(nn.Module): - r""" SwinIR - A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer. - - Args: - img_size (int | tuple(int)): Input image size. Default 64 - patch_size (int | tuple(int)): Patch size. Default: 1 - in_chans (int): Number of input image channels. Default: 3 - embed_dim (int): Patch embedding dimension. Default: 96 - depths (tuple(int)): Depth of each Swin Transformer layer. - num_heads (tuple(int)): Number of attention heads in different layers. - window_size (int): Window size. Default: 7 - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 - qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None - drop_rate (float): Dropout rate. Default: 0 - attn_drop_rate (float): Attention dropout rate. Default: 0 - drop_path_rate (float): Stochastic depth rate. Default: 0.1 - norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. - ape (bool): If True, add absolute position embedding to the patch embedding. Default: False - patch_norm (bool): If True, add normalization after patch embedding. Default: True - use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False - upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction - img_range: Image range. 1. or 255. - upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None - resi_connection: The convolutional block before residual connection. '1conv'/'3conv' - """ - - def __init__(self, img_size=64, patch_size=1, in_chans=3, - embed_dim=96, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6), - window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, - drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, - norm_layer=nn.LayerNorm, ape=False, patch_norm=True, - use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv', - **kwargs): - super(SwinIR, self).__init__() - num_in_ch = in_chans - num_out_ch = in_chans - num_feat = 64 - self.img_range = img_range - if in_chans == 3: - rgb_mean = (0.4488, 0.4371, 0.4040) - self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) - else: - self.mean = torch.zeros(1, 1, 1, 1) - self.upscale = upscale - self.upsampler = upsampler - self.window_size = window_size - - ##################################################################################################### - ################################### 1, shallow feature extraction ################################### - self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1) - - ##################################################################################################### - ################################### 2, deep feature extraction ###################################### - self.num_layers = len(depths) - self.embed_dim = embed_dim - self.ape = ape - self.patch_norm = patch_norm - self.num_features = embed_dim - self.mlp_ratio = mlp_ratio - - # split image into non-overlapping patches - self.patch_embed = PatchEmbed( - img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, - norm_layer=norm_layer if self.patch_norm else None) - num_patches = self.patch_embed.num_patches - patches_resolution = self.patch_embed.patches_resolution - self.patches_resolution = patches_resolution - - # merge non-overlapping patches into image - self.patch_unembed = PatchUnEmbed( - img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, - norm_layer=norm_layer if self.patch_norm else None) - - # absolute position embedding - if self.ape: - self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) - trunc_normal_(self.absolute_pos_embed, std=.02) - - self.pos_drop = nn.Dropout(p=drop_rate) - - # stochastic depth - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule - - # build Residual Swin Transformer blocks (RSTB) - self.layers = nn.ModuleList() - for i_layer in range(self.num_layers): - layer = RSTB(dim=embed_dim, - input_resolution=(patches_resolution[0], - patches_resolution[1]), - depth=depths[i_layer], - num_heads=num_heads[i_layer], - window_size=window_size, - mlp_ratio=self.mlp_ratio, - qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop_rate, attn_drop=attn_drop_rate, - drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results - norm_layer=norm_layer, - downsample=None, - use_checkpoint=use_checkpoint, - img_size=img_size, - patch_size=patch_size, - resi_connection=resi_connection - - ) - self.layers.append(layer) - self.norm = norm_layer(self.num_features) - - # build the last conv layer in deep feature extraction - if resi_connection == '1conv': - self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) - elif resi_connection == '3conv': - # to save parameters and memory - self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), - nn.LeakyReLU(negative_slope=0.2, inplace=True), - nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), - nn.LeakyReLU(negative_slope=0.2, inplace=True), - nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1)) - - ##################################################################################################### - ################################ 3, high quality image reconstruction ################################ - if self.upsampler == 'pixelshuffle': - # for classical SR - self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), - nn.LeakyReLU(inplace=True)) - self.upsample = Upsample(upscale, num_feat) - self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) - elif self.upsampler == 'pixelshuffledirect': - # for lightweight SR (to save parameters) - self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch, - (patches_resolution[0], patches_resolution[1])) - elif self.upsampler == 'nearest+conv': - # for real-world SR (less artifacts) - self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), - nn.LeakyReLU(inplace=True)) - self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) - if self.upscale == 4: - self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) - self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) - self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) - self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) - else: - # for image denoising and JPEG compression artifact reduction - self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1) - - self.apply(self._init_weights) - - def _init_weights(self, m): - if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - - @torch.jit.ignore - def no_weight_decay(self): - return {'absolute_pos_embed'} - - @torch.jit.ignore - def no_weight_decay_keywords(self): - return {'relative_position_bias_table'} - - def check_image_size(self, x): - _, _, h, w = x.size() - mod_pad_h = (self.window_size - h % self.window_size) % self.window_size - mod_pad_w = (self.window_size - w % self.window_size) % self.window_size - x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') - return x - - def forward_features(self, x): - x_size = (x.shape[2], x.shape[3]) - x = self.patch_embed(x) - if self.ape: - x = x + self.absolute_pos_embed - x = self.pos_drop(x) - - for layer in self.layers: - x = layer(x, x_size) - - x = self.norm(x) # B L C - x = self.patch_unembed(x, x_size) - - return x - - def forward(self, x): - H, W = x.shape[2:] - x = self.check_image_size(x) - - self.mean = self.mean.type_as(x) - x = (x - self.mean) * self.img_range - - if self.upsampler == 'pixelshuffle': - # for classical SR - x = self.conv_first(x) - x = self.conv_after_body(self.forward_features(x)) + x - x = self.conv_before_upsample(x) - x = self.conv_last(self.upsample(x)) - elif self.upsampler == 'pixelshuffledirect': - # for lightweight SR - x = self.conv_first(x) - x = self.conv_after_body(self.forward_features(x)) + x - x = self.upsample(x) - elif self.upsampler == 'nearest+conv': - # for real-world SR - x = self.conv_first(x) - x = self.conv_after_body(self.forward_features(x)) + x - x = self.conv_before_upsample(x) - x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) - if self.upscale == 4: - x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) - x = self.conv_last(self.lrelu(self.conv_hr(x))) - else: - # for image denoising and JPEG compression artifact reduction - x_first = self.conv_first(x) - res = self.conv_after_body(self.forward_features(x_first)) + x_first - x = x + self.conv_last(res) - - x = x / self.img_range + self.mean - - return x[:, :, :H*self.upscale, :W*self.upscale] - - def flops(self): - flops = 0 - H, W = self.patches_resolution - flops += H * W * 3 * self.embed_dim * 9 - flops += self.patch_embed.flops() - for layer in self.layers: - flops += layer.flops() - flops += H * W * 3 * self.embed_dim * self.embed_dim - flops += self.upsample.flops() - return flops - - -if __name__ == '__main__': - upscale = 4 - window_size = 8 - height = (1024 // upscale // window_size + 1) * window_size - width = (720 // upscale // window_size + 1) * window_size - model = SwinIR(upscale=2, img_size=(height, width), - window_size=window_size, img_range=1., depths=[6, 6, 6, 6], - embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect') - print(model) - print(height, width, model.flops() / 1e9) - - x = torch.randn((1, 3, height, width)) - x = model(x) - print(x.shape) diff --git a/extensions-builtin/SwinIR/swinir_model_arch_v2.py b/extensions-builtin/SwinIR/swinir_model_arch_v2.py deleted file mode 100644 index dad22cca2..000000000 --- a/extensions-builtin/SwinIR/swinir_model_arch_v2.py +++ /dev/null @@ -1,1017 +0,0 @@ -# ----------------------------------------------------------------------------------- -# Swin2SR: Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration, https://arxiv.org/abs/ -# Written by Conde and Choi et al. -# ----------------------------------------------------------------------------------- - -import math -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.utils.checkpoint as checkpoint -from timm.models.layers import DropPath, to_2tuple, trunc_normal_ - - -class Mlp(nn.Module): - def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.fc1 = nn.Linear(in_features, hidden_features) - self.act = act_layer() - self.fc2 = nn.Linear(hidden_features, out_features) - self.drop = nn.Dropout(drop) - - def forward(self, x): - x = self.fc1(x) - x = self.act(x) - x = self.drop(x) - x = self.fc2(x) - x = self.drop(x) - return x - - -def window_partition(x, window_size): - """ - Args: - x: (B, H, W, C) - window_size (int): window size - Returns: - windows: (num_windows*B, window_size, window_size, C) - """ - B, H, W, C = x.shape - x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) - windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) - return windows - - -def window_reverse(windows, window_size, H, W): - """ - Args: - windows: (num_windows*B, window_size, window_size, C) - window_size (int): Window size - H (int): Height of image - W (int): Width of image - Returns: - x: (B, H, W, C) - """ - B = int(windows.shape[0] / (H * W / window_size / window_size)) - x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) - x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) - return x - -class WindowAttention(nn.Module): - r""" Window based multi-head self attention (W-MSA) module with relative position bias. - It supports both of shifted and non-shifted window. - Args: - dim (int): Number of input channels. - window_size (tuple[int]): The height and width of the window. - num_heads (int): Number of attention heads. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 - proj_drop (float, optional): Dropout ratio of output. Default: 0.0 - pretrained_window_size (tuple[int]): The height and width of the window in pre-training. - """ - - def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0., - pretrained_window_size=(0, 0)): - - super().__init__() - self.dim = dim - self.window_size = window_size # Wh, Ww - self.pretrained_window_size = pretrained_window_size - self.num_heads = num_heads - - self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True) - - # mlp to generate continuous relative position bias - self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True), - nn.ReLU(inplace=True), - nn.Linear(512, num_heads, bias=False)) - - # get relative_coords_table - relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32) - relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32) - relative_coords_table = torch.stack( - torch.meshgrid([relative_coords_h, - relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2 - if pretrained_window_size[0] > 0: - relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1) - relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1) - else: - relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1) - relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1) - relative_coords_table *= 8 # normalize to -8, 8 - relative_coords_table = torch.sign(relative_coords_table) * torch.log2( - torch.abs(relative_coords_table) + 1.0) / np.log2(8) - - self.register_buffer("relative_coords_table", relative_coords_table) - - # get pair-wise relative position index for each token inside the window - coords_h = torch.arange(self.window_size[0]) - coords_w = torch.arange(self.window_size[1]) - coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww - coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww - relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww - relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 - relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 - relative_coords[:, :, 1] += self.window_size[1] - 1 - relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 - relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww - self.register_buffer("relative_position_index", relative_position_index) - - self.qkv = nn.Linear(dim, dim * 3, bias=False) - if qkv_bias: - self.q_bias = nn.Parameter(torch.zeros(dim)) - self.v_bias = nn.Parameter(torch.zeros(dim)) - else: - self.q_bias = None - self.v_bias = None - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) - self.proj_drop = nn.Dropout(proj_drop) - self.softmax = nn.Softmax(dim=-1) - - def forward(self, x, mask=None): - """ - Args: - x: input features with shape of (num_windows*B, N, C) - mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None - """ - B_, N, C = x.shape - qkv_bias = None - if self.q_bias is not None: - qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias)) - qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) - qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) - q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) - - # cosine attention - attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)) - logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. / 0.01)).to(self.logit_scale.device)).exp() - attn = attn * logit_scale - - relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads) - relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view( - self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH - relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww - relative_position_bias = 16 * torch.sigmoid(relative_position_bias) - attn = attn + relative_position_bias.unsqueeze(0) - - if mask is not None: - nW = mask.shape[0] - attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) - attn = attn.view(-1, self.num_heads, N, N) - attn = self.softmax(attn) - else: - attn = self.softmax(attn) - - attn = self.attn_drop(attn) - - x = (attn @ v).transpose(1, 2).reshape(B_, N, C) - x = self.proj(x) - x = self.proj_drop(x) - return x - - def extra_repr(self) -> str: - return f'dim={self.dim}, window_size={self.window_size}, ' \ - f'pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}' - - def flops(self, N): - # calculate flops for 1 window with token length of N - flops = 0 - # qkv = self.qkv(x) - flops += N * self.dim * 3 * self.dim - # attn = (q @ k.transpose(-2, -1)) - flops += self.num_heads * N * (self.dim // self.num_heads) * N - # x = (attn @ v) - flops += self.num_heads * N * N * (self.dim // self.num_heads) - # x = self.proj(x) - flops += N * self.dim * self.dim - return flops - -class SwinTransformerBlock(nn.Module): - r""" Swin Transformer Block. - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resulotion. - num_heads (int): Number of attention heads. - window_size (int): Window size. - shift_size (int): Shift size for SW-MSA. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float, optional): Stochastic depth rate. Default: 0.0 - act_layer (nn.Module, optional): Activation layer. Default: nn.GELU - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - pretrained_window_size (int): Window size in pre-training. - """ - - def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, - mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0., - act_layer=nn.GELU, norm_layer=nn.LayerNorm, pretrained_window_size=0): - super().__init__() - self.dim = dim - self.input_resolution = input_resolution - self.num_heads = num_heads - self.window_size = window_size - self.shift_size = shift_size - self.mlp_ratio = mlp_ratio - if min(self.input_resolution) <= self.window_size: - # if window size is larger than input resolution, we don't partition windows - self.shift_size = 0 - self.window_size = min(self.input_resolution) - assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" - - self.norm1 = norm_layer(dim) - self.attn = WindowAttention( - dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, - qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, - pretrained_window_size=to_2tuple(pretrained_window_size)) - - self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() - self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) - - if self.shift_size > 0: - attn_mask = self.calculate_mask(self.input_resolution) - else: - attn_mask = None - - self.register_buffer("attn_mask", attn_mask) - - def calculate_mask(self, x_size): - # calculate attention mask for SW-MSA - H, W = x_size - img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 - h_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - w_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - cnt = 0 - for h in h_slices: - for w in w_slices: - img_mask[:, h, w, :] = cnt - cnt += 1 - - mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 - mask_windows = mask_windows.view(-1, self.window_size * self.window_size) - attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) - attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) - - return attn_mask - - def forward(self, x, x_size): - H, W = x_size - B, L, C = x.shape - #assert L == H * W, "input feature has wrong size" - - shortcut = x - x = x.view(B, H, W, C) - - # cyclic shift - if self.shift_size > 0: - shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) - else: - shifted_x = x - - # partition windows - x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C - x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C - - # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size - if self.input_resolution == x_size: - attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C - else: - attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) - - # merge windows - attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) - shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C - - # reverse cyclic shift - if self.shift_size > 0: - x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) - else: - x = shifted_x - x = x.view(B, H * W, C) - x = shortcut + self.drop_path(self.norm1(x)) - - # FFN - x = x + self.drop_path(self.norm2(self.mlp(x))) - - return x - - def extra_repr(self) -> str: - return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ - f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" - - def flops(self): - flops = 0 - H, W = self.input_resolution - # norm1 - flops += self.dim * H * W - # W-MSA/SW-MSA - nW = H * W / self.window_size / self.window_size - flops += nW * self.attn.flops(self.window_size * self.window_size) - # mlp - flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio - # norm2 - flops += self.dim * H * W - return flops - -class PatchMerging(nn.Module): - r""" Patch Merging Layer. - Args: - input_resolution (tuple[int]): Resolution of input feature. - dim (int): Number of input channels. - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - """ - - def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): - super().__init__() - self.input_resolution = input_resolution - self.dim = dim - self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) - self.norm = norm_layer(2 * dim) - - def forward(self, x): - """ - x: B, H*W, C - """ - H, W = self.input_resolution - B, L, C = x.shape - assert L == H * W, "input feature has wrong size" - assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." - - x = x.view(B, H, W, C) - - x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C - x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C - x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C - x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C - x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C - x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C - - x = self.reduction(x) - x = self.norm(x) - - return x - - def extra_repr(self) -> str: - return f"input_resolution={self.input_resolution}, dim={self.dim}" - - def flops(self): - H, W = self.input_resolution - flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim - flops += H * W * self.dim // 2 - return flops - -class BasicLayer(nn.Module): - """ A basic Swin Transformer layer for one stage. - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resolution. - depth (int): Number of blocks. - num_heads (int): Number of attention heads. - window_size (int): Local window size. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None - use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. - pretrained_window_size (int): Local window size in pre-training. - """ - - def __init__(self, dim, input_resolution, depth, num_heads, window_size, - mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., - drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, - pretrained_window_size=0): - - super().__init__() - self.dim = dim - self.input_resolution = input_resolution - self.depth = depth - self.use_checkpoint = use_checkpoint - - # build blocks - self.blocks = nn.ModuleList([ - SwinTransformerBlock(dim=dim, input_resolution=input_resolution, - num_heads=num_heads, window_size=window_size, - shift_size=0 if (i % 2 == 0) else window_size // 2, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - drop=drop, attn_drop=attn_drop, - drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, - norm_layer=norm_layer, - pretrained_window_size=pretrained_window_size) - for i in range(depth)]) - - # patch merging layer - if downsample is not None: - self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) - else: - self.downsample = None - - def forward(self, x, x_size): - for blk in self.blocks: - if self.use_checkpoint: - x = checkpoint.checkpoint(blk, x, x_size) - else: - x = blk(x, x_size) - if self.downsample is not None: - x = self.downsample(x) - return x - - def extra_repr(self) -> str: - return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" - - def flops(self): - flops = 0 - for blk in self.blocks: - flops += blk.flops() - if self.downsample is not None: - flops += self.downsample.flops() - return flops - - def _init_respostnorm(self): - for blk in self.blocks: - nn.init.constant_(blk.norm1.bias, 0) - nn.init.constant_(blk.norm1.weight, 0) - nn.init.constant_(blk.norm2.bias, 0) - nn.init.constant_(blk.norm2.weight, 0) - -class PatchEmbed(nn.Module): - r""" Image to Patch Embedding - Args: - img_size (int): Image size. Default: 224. - patch_size (int): Patch token size. Default: 4. - in_chans (int): Number of input image channels. Default: 3. - embed_dim (int): Number of linear projection output channels. Default: 96. - norm_layer (nn.Module, optional): Normalization layer. Default: None - """ - - def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): - super().__init__() - img_size = to_2tuple(img_size) - patch_size = to_2tuple(patch_size) - patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] - self.img_size = img_size - self.patch_size = patch_size - self.patches_resolution = patches_resolution - self.num_patches = patches_resolution[0] * patches_resolution[1] - - self.in_chans = in_chans - self.embed_dim = embed_dim - - self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) - if norm_layer is not None: - self.norm = norm_layer(embed_dim) - else: - self.norm = None - - def forward(self, x): - B, C, H, W = x.shape - # FIXME look at relaxing size constraints - # assert H == self.img_size[0] and W == self.img_size[1], - # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." - x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C - if self.norm is not None: - x = self.norm(x) - return x - - def flops(self): - Ho, Wo = self.patches_resolution - flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) - if self.norm is not None: - flops += Ho * Wo * self.embed_dim - return flops - -class RSTB(nn.Module): - """Residual Swin Transformer Block (RSTB). - - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resolution. - depth (int): Number of blocks. - num_heads (int): Number of attention heads. - window_size (int): Local window size. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None - use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. - img_size: Input image size. - patch_size: Patch size. - resi_connection: The convolutional block before residual connection. - """ - - def __init__(self, dim, input_resolution, depth, num_heads, window_size, - mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., - drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, - img_size=224, patch_size=4, resi_connection='1conv'): - super(RSTB, self).__init__() - - self.dim = dim - self.input_resolution = input_resolution - - self.residual_group = BasicLayer(dim=dim, - input_resolution=input_resolution, - depth=depth, - num_heads=num_heads, - window_size=window_size, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - drop=drop, attn_drop=attn_drop, - drop_path=drop_path, - norm_layer=norm_layer, - downsample=downsample, - use_checkpoint=use_checkpoint) - - if resi_connection == '1conv': - self.conv = nn.Conv2d(dim, dim, 3, 1, 1) - elif resi_connection == '3conv': - # to save parameters and memory - self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), - nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), - nn.LeakyReLU(negative_slope=0.2, inplace=True), - nn.Conv2d(dim // 4, dim, 3, 1, 1)) - - self.patch_embed = PatchEmbed( - img_size=img_size, patch_size=patch_size, in_chans=dim, embed_dim=dim, - norm_layer=None) - - self.patch_unembed = PatchUnEmbed( - img_size=img_size, patch_size=patch_size, in_chans=dim, embed_dim=dim, - norm_layer=None) - - def forward(self, x, x_size): - return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x - - def flops(self): - flops = 0 - flops += self.residual_group.flops() - H, W = self.input_resolution - flops += H * W * self.dim * self.dim * 9 - flops += self.patch_embed.flops() - flops += self.patch_unembed.flops() - - return flops - -class PatchUnEmbed(nn.Module): - r""" Image to Patch Unembedding - - Args: - img_size (int): Image size. Default: 224. - patch_size (int): Patch token size. Default: 4. - in_chans (int): Number of input image channels. Default: 3. - embed_dim (int): Number of linear projection output channels. Default: 96. - norm_layer (nn.Module, optional): Normalization layer. Default: None - """ - - def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): - super().__init__() - img_size = to_2tuple(img_size) - patch_size = to_2tuple(patch_size) - patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] - self.img_size = img_size - self.patch_size = patch_size - self.patches_resolution = patches_resolution - self.num_patches = patches_resolution[0] * patches_resolution[1] - - self.in_chans = in_chans - self.embed_dim = embed_dim - - def forward(self, x, x_size): - B, HW, C = x.shape - x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C - return x - - def flops(self): - flops = 0 - return flops - - -class Upsample(nn.Sequential): - """Upsample module. - - Args: - scale (int): Scale factor. Supported scales: 2^n and 3. - num_feat (int): Channel number of intermediate features. - """ - - def __init__(self, scale, num_feat): - m = [] - if (scale & (scale - 1)) == 0: # scale = 2^n - for _ in range(int(math.log(scale, 2))): - m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) - m.append(nn.PixelShuffle(2)) - elif scale == 3: - m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) - m.append(nn.PixelShuffle(3)) - else: - raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') - super(Upsample, self).__init__(*m) - -class Upsample_hf(nn.Sequential): - """Upsample module. - - Args: - scale (int): Scale factor. Supported scales: 2^n and 3. - num_feat (int): Channel number of intermediate features. - """ - - def __init__(self, scale, num_feat): - m = [] - if (scale & (scale - 1)) == 0: # scale = 2^n - for _ in range(int(math.log(scale, 2))): - m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) - m.append(nn.PixelShuffle(2)) - elif scale == 3: - m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) - m.append(nn.PixelShuffle(3)) - else: - raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') - super(Upsample_hf, self).__init__(*m) - - -class UpsampleOneStep(nn.Sequential): - """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle) - Used in lightweight SR to save parameters. - - Args: - scale (int): Scale factor. Supported scales: 2^n and 3. - num_feat (int): Channel number of intermediate features. - - """ - - def __init__(self, scale, num_feat, num_out_ch, input_resolution=None): - self.num_feat = num_feat - self.input_resolution = input_resolution - m = [] - m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1)) - m.append(nn.PixelShuffle(scale)) - super(UpsampleOneStep, self).__init__(*m) - - def flops(self): - H, W = self.input_resolution - flops = H * W * self.num_feat * 3 * 9 - return flops - - - -class Swin2SR(nn.Module): - r""" Swin2SR - A PyTorch impl of : `Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration`. - - Args: - img_size (int | tuple(int)): Input image size. Default 64 - patch_size (int | tuple(int)): Patch size. Default: 1 - in_chans (int): Number of input image channels. Default: 3 - embed_dim (int): Patch embedding dimension. Default: 96 - depths (tuple(int)): Depth of each Swin Transformer layer. - num_heads (tuple(int)): Number of attention heads in different layers. - window_size (int): Window size. Default: 7 - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 - qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True - drop_rate (float): Dropout rate. Default: 0 - attn_drop_rate (float): Attention dropout rate. Default: 0 - drop_path_rate (float): Stochastic depth rate. Default: 0.1 - norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. - ape (bool): If True, add absolute position embedding to the patch embedding. Default: False - patch_norm (bool): If True, add normalization after patch embedding. Default: True - use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False - upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction - img_range: Image range. 1. or 255. - upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None - resi_connection: The convolutional block before residual connection. '1conv'/'3conv' - """ - - def __init__(self, img_size=64, patch_size=1, in_chans=3, - embed_dim=96, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6), - window_size=7, mlp_ratio=4., qkv_bias=True, - drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, - norm_layer=nn.LayerNorm, ape=False, patch_norm=True, - use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv', - **kwargs): - super(Swin2SR, self).__init__() - num_in_ch = in_chans - num_out_ch = in_chans - num_feat = 64 - self.img_range = img_range - if in_chans == 3: - rgb_mean = (0.4488, 0.4371, 0.4040) - self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) - else: - self.mean = torch.zeros(1, 1, 1, 1) - self.upscale = upscale - self.upsampler = upsampler - self.window_size = window_size - - ##################################################################################################### - ################################### 1, shallow feature extraction ################################### - self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1) - - ##################################################################################################### - ################################### 2, deep feature extraction ###################################### - self.num_layers = len(depths) - self.embed_dim = embed_dim - self.ape = ape - self.patch_norm = patch_norm - self.num_features = embed_dim - self.mlp_ratio = mlp_ratio - - # split image into non-overlapping patches - self.patch_embed = PatchEmbed( - img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, - norm_layer=norm_layer if self.patch_norm else None) - num_patches = self.patch_embed.num_patches - patches_resolution = self.patch_embed.patches_resolution - self.patches_resolution = patches_resolution - - # merge non-overlapping patches into image - self.patch_unembed = PatchUnEmbed( - img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, - norm_layer=norm_layer if self.patch_norm else None) - - # absolute position embedding - if self.ape: - self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) - trunc_normal_(self.absolute_pos_embed, std=.02) - - self.pos_drop = nn.Dropout(p=drop_rate) - - # stochastic depth - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule - - # build Residual Swin Transformer blocks (RSTB) - self.layers = nn.ModuleList() - for i_layer in range(self.num_layers): - layer = RSTB(dim=embed_dim, - input_resolution=(patches_resolution[0], - patches_resolution[1]), - depth=depths[i_layer], - num_heads=num_heads[i_layer], - window_size=window_size, - mlp_ratio=self.mlp_ratio, - qkv_bias=qkv_bias, - drop=drop_rate, attn_drop=attn_drop_rate, - drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results - norm_layer=norm_layer, - downsample=None, - use_checkpoint=use_checkpoint, - img_size=img_size, - patch_size=patch_size, - resi_connection=resi_connection - - ) - self.layers.append(layer) - - if self.upsampler == 'pixelshuffle_hf': - self.layers_hf = nn.ModuleList() - for i_layer in range(self.num_layers): - layer = RSTB(dim=embed_dim, - input_resolution=(patches_resolution[0], - patches_resolution[1]), - depth=depths[i_layer], - num_heads=num_heads[i_layer], - window_size=window_size, - mlp_ratio=self.mlp_ratio, - qkv_bias=qkv_bias, - drop=drop_rate, attn_drop=attn_drop_rate, - drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results - norm_layer=norm_layer, - downsample=None, - use_checkpoint=use_checkpoint, - img_size=img_size, - patch_size=patch_size, - resi_connection=resi_connection - - ) - self.layers_hf.append(layer) - - self.norm = norm_layer(self.num_features) - - # build the last conv layer in deep feature extraction - if resi_connection == '1conv': - self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) - elif resi_connection == '3conv': - # to save parameters and memory - self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), - nn.LeakyReLU(negative_slope=0.2, inplace=True), - nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), - nn.LeakyReLU(negative_slope=0.2, inplace=True), - nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1)) - - ##################################################################################################### - ################################ 3, high quality image reconstruction ################################ - if self.upsampler == 'pixelshuffle': - # for classical SR - self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), - nn.LeakyReLU(inplace=True)) - self.upsample = Upsample(upscale, num_feat) - self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) - elif self.upsampler == 'pixelshuffle_aux': - self.conv_bicubic = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) - self.conv_before_upsample = nn.Sequential( - nn.Conv2d(embed_dim, num_feat, 3, 1, 1), - nn.LeakyReLU(inplace=True)) - self.conv_aux = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) - self.conv_after_aux = nn.Sequential( - nn.Conv2d(3, num_feat, 3, 1, 1), - nn.LeakyReLU(inplace=True)) - self.upsample = Upsample(upscale, num_feat) - self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) - - elif self.upsampler == 'pixelshuffle_hf': - self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), - nn.LeakyReLU(inplace=True)) - self.upsample = Upsample(upscale, num_feat) - self.upsample_hf = Upsample_hf(upscale, num_feat) - self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) - self.conv_first_hf = nn.Sequential(nn.Conv2d(num_feat, embed_dim, 3, 1, 1), - nn.LeakyReLU(inplace=True)) - self.conv_after_body_hf = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) - self.conv_before_upsample_hf = nn.Sequential( - nn.Conv2d(embed_dim, num_feat, 3, 1, 1), - nn.LeakyReLU(inplace=True)) - self.conv_last_hf = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) - - elif self.upsampler == 'pixelshuffledirect': - # for lightweight SR (to save parameters) - self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch, - (patches_resolution[0], patches_resolution[1])) - elif self.upsampler == 'nearest+conv': - # for real-world SR (less artifacts) - assert self.upscale == 4, 'only support x4 now.' - self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), - nn.LeakyReLU(inplace=True)) - self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) - self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) - self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) - self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) - self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) - else: - # for image denoising and JPEG compression artifact reduction - self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1) - - self.apply(self._init_weights) - - def _init_weights(self, m): - if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - - @torch.jit.ignore - def no_weight_decay(self): - return {'absolute_pos_embed'} - - @torch.jit.ignore - def no_weight_decay_keywords(self): - return {'relative_position_bias_table'} - - def check_image_size(self, x): - _, _, h, w = x.size() - mod_pad_h = (self.window_size - h % self.window_size) % self.window_size - mod_pad_w = (self.window_size - w % self.window_size) % self.window_size - x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') - return x - - def forward_features(self, x): - x_size = (x.shape[2], x.shape[3]) - x = self.patch_embed(x) - if self.ape: - x = x + self.absolute_pos_embed - x = self.pos_drop(x) - - for layer in self.layers: - x = layer(x, x_size) - - x = self.norm(x) # B L C - x = self.patch_unembed(x, x_size) - - return x - - def forward_features_hf(self, x): - x_size = (x.shape[2], x.shape[3]) - x = self.patch_embed(x) - if self.ape: - x = x + self.absolute_pos_embed - x = self.pos_drop(x) - - for layer in self.layers_hf: - x = layer(x, x_size) - - x = self.norm(x) # B L C - x = self.patch_unembed(x, x_size) - - return x - - def forward(self, x): - H, W = x.shape[2:] - x = self.check_image_size(x) - - self.mean = self.mean.type_as(x) - x = (x - self.mean) * self.img_range - - if self.upsampler == 'pixelshuffle': - # for classical SR - x = self.conv_first(x) - x = self.conv_after_body(self.forward_features(x)) + x - x = self.conv_before_upsample(x) - x = self.conv_last(self.upsample(x)) - elif self.upsampler == 'pixelshuffle_aux': - bicubic = F.interpolate(x, size=(H * self.upscale, W * self.upscale), mode='bicubic', align_corners=False) - bicubic = self.conv_bicubic(bicubic) - x = self.conv_first(x) - x = self.conv_after_body(self.forward_features(x)) + x - x = self.conv_before_upsample(x) - aux = self.conv_aux(x) # b, 3, LR_H, LR_W - x = self.conv_after_aux(aux) - x = self.upsample(x)[:, :, :H * self.upscale, :W * self.upscale] + bicubic[:, :, :H * self.upscale, :W * self.upscale] - x = self.conv_last(x) - aux = aux / self.img_range + self.mean - elif self.upsampler == 'pixelshuffle_hf': - # for classical SR with HF - x = self.conv_first(x) - x = self.conv_after_body(self.forward_features(x)) + x - x_before = self.conv_before_upsample(x) - x_out = self.conv_last(self.upsample(x_before)) - - x_hf = self.conv_first_hf(x_before) - x_hf = self.conv_after_body_hf(self.forward_features_hf(x_hf)) + x_hf - x_hf = self.conv_before_upsample_hf(x_hf) - x_hf = self.conv_last_hf(self.upsample_hf(x_hf)) - x = x_out + x_hf - x_hf = x_hf / self.img_range + self.mean - - elif self.upsampler == 'pixelshuffledirect': - # for lightweight SR - x = self.conv_first(x) - x = self.conv_after_body(self.forward_features(x)) + x - x = self.upsample(x) - elif self.upsampler == 'nearest+conv': - # for real-world SR - x = self.conv_first(x) - x = self.conv_after_body(self.forward_features(x)) + x - x = self.conv_before_upsample(x) - x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) - x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) - x = self.conv_last(self.lrelu(self.conv_hr(x))) - else: - # for image denoising and JPEG compression artifact reduction - x_first = self.conv_first(x) - res = self.conv_after_body(self.forward_features(x_first)) + x_first - x = x + self.conv_last(res) - - x = x / self.img_range + self.mean - if self.upsampler == "pixelshuffle_aux": - return x[:, :, :H*self.upscale, :W*self.upscale], aux - - elif self.upsampler == "pixelshuffle_hf": - x_out = x_out / self.img_range + self.mean - return x_out[:, :, :H*self.upscale, :W*self.upscale], x[:, :, :H*self.upscale, :W*self.upscale], x_hf[:, :, :H*self.upscale, :W*self.upscale] - - else: - return x[:, :, :H*self.upscale, :W*self.upscale] - - def flops(self): - flops = 0 - H, W = self.patches_resolution - flops += H * W * 3 * self.embed_dim * 9 - flops += self.patch_embed.flops() - for layer in self.layers: - flops += layer.flops() - flops += H * W * 3 * self.embed_dim * self.embed_dim - flops += self.upsample.flops() - return flops - - -if __name__ == '__main__': - upscale = 4 - window_size = 8 - height = (1024 // upscale // window_size + 1) * window_size - width = (720 // upscale // window_size + 1) * window_size - model = Swin2SR(upscale=2, img_size=(height, width), - window_size=window_size, img_range=1., depths=[6, 6, 6, 6], - embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect') - print(model) - print(height, width, model.flops() / 1e9) - - x = torch.randn((1, 3, height, width)) - x = model(x) - print(x.shape) diff --git a/extensions-builtin/extra-options-section/scripts/extra_options_section.py b/extensions-builtin/extra-options-section/scripts/extra_options_section.py index 983f87ff0..8aa901fd4 100644 --- a/extensions-builtin/extra-options-section/scripts/extra_options_section.py +++ b/extensions-builtin/extra-options-section/scripts/extra_options_section.py @@ -1,7 +1,7 @@ import math import gradio as gr -from modules import scripts, shared, ui_components, ui_settings, generation_parameters_copypaste +from modules import scripts, shared, ui_components, ui_settings, infotext from modules.ui_components import FormColumn @@ -23,11 +23,12 @@ class ExtraOptionsSection(scripts.Script): self.setting_names = [] self.infotext_fields = [] extra_options = shared.opts.extra_options_img2img if is_img2img else shared.opts.extra_options_txt2img + elem_id_tabname = "extra_options_" + ("img2img" if is_img2img else "txt2img") - mapping = {k: v for v, k in generation_parameters_copypaste.infotext_to_setting_name_mapping} + mapping = {k: v for v, k in infotext.infotext_to_setting_name_mapping} with gr.Blocks() as interface: - with gr.Accordion("Options", open=False) if shared.opts.extra_options_accordion and extra_options else gr.Group(): + with gr.Accordion("Options", open=False, elem_id=elem_id_tabname) if shared.opts.extra_options_accordion and extra_options else gr.Group(elem_id=elem_id_tabname): row_count = math.ceil(len(extra_options) / shared.opts.extra_options_cols) @@ -64,11 +65,14 @@ class ExtraOptionsSection(scripts.Script): p.override_settings[name] = value -shared.options_templates.update(shared.options_section(('ui', "User interface"), { - "extra_options_txt2img": shared.OptionInfo([], "Options in main UI - txt2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in txt2img interfaces").needs_reload_ui(), - "extra_options_img2img": shared.OptionInfo([], "Options in main UI - img2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in img2img interfaces").needs_reload_ui(), - "extra_options_cols": shared.OptionInfo(1, "Options in main UI - number of columns", gr.Number, {"precision": 0}).needs_reload_ui(), - "extra_options_accordion": shared.OptionInfo(False, "Options in main UI - place into an accordion").needs_reload_ui() +shared.options_templates.update(shared.options_section(('settings_in_ui', "Settings in UI", "ui"), { + "settings_in_ui": shared.OptionHTML(""" +This page allows you to add some settings to the main interface of txt2img and img2img tabs. +"""), + "extra_options_txt2img": shared.OptionInfo([], "Settings for txt2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in txt2img interfaces").needs_reload_ui(), + "extra_options_img2img": shared.OptionInfo([], "Settings for img2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in img2img interfaces").needs_reload_ui(), + "extra_options_cols": shared.OptionInfo(1, "Number of columns for added settings", gr.Slider, {"step": 1, "minimum": 1, "maximum": 20}).info("displayed amount will depend on the actual browser window width").needs_reload_ui(), + "extra_options_accordion": shared.OptionInfo(False, "Place added settings into an accordion").needs_reload_ui() })) diff --git a/extensions-builtin/hypertile/hypertile.py b/extensions-builtin/hypertile/hypertile.py new file mode 100644 index 000000000..0f40e2d39 --- /dev/null +++ b/extensions-builtin/hypertile/hypertile.py @@ -0,0 +1,351 @@ +""" +Hypertile module for splitting attention layers in SD-1.5 U-Net and SD-1.5 VAE +Warn: The patch works well only if the input image has a width and height that are multiples of 128 +Original author: @tfernd Github: https://github.com/tfernd/HyperTile +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Callable + +from functools import wraps, cache + +import math +import torch.nn as nn +import random + +from einops import rearrange + + +@dataclass +class HypertileParams: + depth = 0 + layer_name = "" + tile_size: int = 0 + swap_size: int = 0 + aspect_ratio: float = 1.0 + forward = None + enabled = False + + + +# TODO add SD-XL layers +DEPTH_LAYERS = { + 0: [ + # SD 1.5 U-Net (diffusers) + "down_blocks.0.attentions.0.transformer_blocks.0.attn1", + "down_blocks.0.attentions.1.transformer_blocks.0.attn1", + "up_blocks.3.attentions.0.transformer_blocks.0.attn1", + "up_blocks.3.attentions.1.transformer_blocks.0.attn1", + "up_blocks.3.attentions.2.transformer_blocks.0.attn1", + # SD 1.5 U-Net (ldm) + "input_blocks.1.1.transformer_blocks.0.attn1", + "input_blocks.2.1.transformer_blocks.0.attn1", + "output_blocks.9.1.transformer_blocks.0.attn1", + "output_blocks.10.1.transformer_blocks.0.attn1", + "output_blocks.11.1.transformer_blocks.0.attn1", + # SD 1.5 VAE + "decoder.mid_block.attentions.0", + "decoder.mid.attn_1", + ], + 1: [ + # SD 1.5 U-Net (diffusers) + "down_blocks.1.attentions.0.transformer_blocks.0.attn1", + "down_blocks.1.attentions.1.transformer_blocks.0.attn1", + "up_blocks.2.attentions.0.transformer_blocks.0.attn1", + "up_blocks.2.attentions.1.transformer_blocks.0.attn1", + "up_blocks.2.attentions.2.transformer_blocks.0.attn1", + # SD 1.5 U-Net (ldm) + "input_blocks.4.1.transformer_blocks.0.attn1", + "input_blocks.5.1.transformer_blocks.0.attn1", + "output_blocks.6.1.transformer_blocks.0.attn1", + "output_blocks.7.1.transformer_blocks.0.attn1", + "output_blocks.8.1.transformer_blocks.0.attn1", + ], + 2: [ + # SD 1.5 U-Net (diffusers) + "down_blocks.2.attentions.0.transformer_blocks.0.attn1", + "down_blocks.2.attentions.1.transformer_blocks.0.attn1", + "up_blocks.1.attentions.0.transformer_blocks.0.attn1", + "up_blocks.1.attentions.1.transformer_blocks.0.attn1", + "up_blocks.1.attentions.2.transformer_blocks.0.attn1", + # SD 1.5 U-Net (ldm) + "input_blocks.7.1.transformer_blocks.0.attn1", + "input_blocks.8.1.transformer_blocks.0.attn1", + "output_blocks.3.1.transformer_blocks.0.attn1", + "output_blocks.4.1.transformer_blocks.0.attn1", + "output_blocks.5.1.transformer_blocks.0.attn1", + ], + 3: [ + # SD 1.5 U-Net (diffusers) + "mid_block.attentions.0.transformer_blocks.0.attn1", + # SD 1.5 U-Net (ldm) + "middle_block.1.transformer_blocks.0.attn1", + ], +} +# XL layers, thanks for GitHub@gel-crabs for the help +DEPTH_LAYERS_XL = { + 0: [ + # SD 1.5 U-Net (diffusers) + "down_blocks.0.attentions.0.transformer_blocks.0.attn1", + "down_blocks.0.attentions.1.transformer_blocks.0.attn1", + "up_blocks.3.attentions.0.transformer_blocks.0.attn1", + "up_blocks.3.attentions.1.transformer_blocks.0.attn1", + "up_blocks.3.attentions.2.transformer_blocks.0.attn1", + # SD 1.5 U-Net (ldm) + "input_blocks.4.1.transformer_blocks.0.attn1", + "input_blocks.5.1.transformer_blocks.0.attn1", + "output_blocks.3.1.transformer_blocks.0.attn1", + "output_blocks.4.1.transformer_blocks.0.attn1", + "output_blocks.5.1.transformer_blocks.0.attn1", + # SD 1.5 VAE + "decoder.mid_block.attentions.0", + "decoder.mid.attn_1", + ], + 1: [ + # SD 1.5 U-Net (diffusers) + #"down_blocks.1.attentions.0.transformer_blocks.0.attn1", + #"down_blocks.1.attentions.1.transformer_blocks.0.attn1", + #"up_blocks.2.attentions.0.transformer_blocks.0.attn1", + #"up_blocks.2.attentions.1.transformer_blocks.0.attn1", + #"up_blocks.2.attentions.2.transformer_blocks.0.attn1", + # SD 1.5 U-Net (ldm) + "input_blocks.4.1.transformer_blocks.1.attn1", + "input_blocks.5.1.transformer_blocks.1.attn1", + "output_blocks.3.1.transformer_blocks.1.attn1", + "output_blocks.4.1.transformer_blocks.1.attn1", + "output_blocks.5.1.transformer_blocks.1.attn1", + "input_blocks.7.1.transformer_blocks.0.attn1", + "input_blocks.8.1.transformer_blocks.0.attn1", + "output_blocks.0.1.transformer_blocks.0.attn1", + "output_blocks.1.1.transformer_blocks.0.attn1", + "output_blocks.2.1.transformer_blocks.0.attn1", + "input_blocks.7.1.transformer_blocks.1.attn1", + "input_blocks.8.1.transformer_blocks.1.attn1", + "output_blocks.0.1.transformer_blocks.1.attn1", + "output_blocks.1.1.transformer_blocks.1.attn1", + "output_blocks.2.1.transformer_blocks.1.attn1", + "input_blocks.7.1.transformer_blocks.2.attn1", + "input_blocks.8.1.transformer_blocks.2.attn1", + "output_blocks.0.1.transformer_blocks.2.attn1", + "output_blocks.1.1.transformer_blocks.2.attn1", + "output_blocks.2.1.transformer_blocks.2.attn1", + "input_blocks.7.1.transformer_blocks.3.attn1", + "input_blocks.8.1.transformer_blocks.3.attn1", + "output_blocks.0.1.transformer_blocks.3.attn1", + "output_blocks.1.1.transformer_blocks.3.attn1", + "output_blocks.2.1.transformer_blocks.3.attn1", + "input_blocks.7.1.transformer_blocks.4.attn1", + "input_blocks.8.1.transformer_blocks.4.attn1", + "output_blocks.0.1.transformer_blocks.4.attn1", + "output_blocks.1.1.transformer_blocks.4.attn1", + "output_blocks.2.1.transformer_blocks.4.attn1", + "input_blocks.7.1.transformer_blocks.5.attn1", + "input_blocks.8.1.transformer_blocks.5.attn1", + "output_blocks.0.1.transformer_blocks.5.attn1", + "output_blocks.1.1.transformer_blocks.5.attn1", + "output_blocks.2.1.transformer_blocks.5.attn1", + "input_blocks.7.1.transformer_blocks.6.attn1", + "input_blocks.8.1.transformer_blocks.6.attn1", + "output_blocks.0.1.transformer_blocks.6.attn1", + "output_blocks.1.1.transformer_blocks.6.attn1", + "output_blocks.2.1.transformer_blocks.6.attn1", + "input_blocks.7.1.transformer_blocks.7.attn1", + "input_blocks.8.1.transformer_blocks.7.attn1", + "output_blocks.0.1.transformer_blocks.7.attn1", + "output_blocks.1.1.transformer_blocks.7.attn1", + "output_blocks.2.1.transformer_blocks.7.attn1", + "input_blocks.7.1.transformer_blocks.8.attn1", + "input_blocks.8.1.transformer_blocks.8.attn1", + "output_blocks.0.1.transformer_blocks.8.attn1", + "output_blocks.1.1.transformer_blocks.8.attn1", + "output_blocks.2.1.transformer_blocks.8.attn1", + "input_blocks.7.1.transformer_blocks.9.attn1", + "input_blocks.8.1.transformer_blocks.9.attn1", + "output_blocks.0.1.transformer_blocks.9.attn1", + "output_blocks.1.1.transformer_blocks.9.attn1", + "output_blocks.2.1.transformer_blocks.9.attn1", + ], + 2: [ + # SD 1.5 U-Net (diffusers) + "mid_block.attentions.0.transformer_blocks.0.attn1", + # SD 1.5 U-Net (ldm) + "middle_block.1.transformer_blocks.0.attn1", + "middle_block.1.transformer_blocks.1.attn1", + "middle_block.1.transformer_blocks.2.attn1", + "middle_block.1.transformer_blocks.3.attn1", + "middle_block.1.transformer_blocks.4.attn1", + "middle_block.1.transformer_blocks.5.attn1", + "middle_block.1.transformer_blocks.6.attn1", + "middle_block.1.transformer_blocks.7.attn1", + "middle_block.1.transformer_blocks.8.attn1", + "middle_block.1.transformer_blocks.9.attn1", + ], + 3 : [] # TODO - separate layers for SD-XL +} + + +RNG_INSTANCE = random.Random() + +@cache +def get_divisors(value: int, min_value: int, /, max_options: int = 1) -> list[int]: + """ + Returns divisors of value that + x * min_value <= value + in big -> small order, amount of divisors is limited by max_options + """ + max_options = max(1, max_options) # at least 1 option should be returned + min_value = min(min_value, value) + divisors = [i for i in range(min_value, value + 1) if value % i == 0] # divisors in small -> big order + ns = [value // i for i in divisors[:max_options]] # has at least 1 element # big -> small order + return ns + + +def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int: + """ + Returns a random divisor of value that + x * min_value <= value + if max_options is 1, the behavior is deterministic + """ + ns = get_divisors(value, min_value, max_options=max_options) # get cached divisors + idx = RNG_INSTANCE.randint(0, len(ns) - 1) + + return ns[idx] + + +def set_hypertile_seed(seed: int) -> None: + RNG_INSTANCE.seed(seed) + + +@cache +def largest_tile_size_available(width: int, height: int) -> int: + """ + Calculates the largest tile size available for a given width and height + Tile size is always a power of 2 + """ + gcd = math.gcd(width, height) + largest_tile_size_available = 1 + while gcd % (largest_tile_size_available * 2) == 0: + largest_tile_size_available *= 2 + return largest_tile_size_available + + +def iterative_closest_divisors(hw:int, aspect_ratio:float) -> tuple[int, int]: + """ + Finds h and w such that h*w = hw and h/w = aspect_ratio + We check all possible divisors of hw and return the closest to the aspect ratio + """ + divisors = [i for i in range(2, hw + 1) if hw % i == 0] # all divisors of hw + pairs = [(i, hw // i) for i in divisors] # all pairs of divisors of hw + ratios = [w/h for h, w in pairs] # all ratios of pairs of divisors of hw + closest_ratio = min(ratios, key=lambda x: abs(x - aspect_ratio)) # closest ratio to aspect_ratio + closest_pair = pairs[ratios.index(closest_ratio)] # closest pair of divisors to aspect_ratio + return closest_pair + + +@cache +def find_hw_candidates(hw:int, aspect_ratio:float) -> tuple[int, int]: + """ + Finds h and w such that h*w = hw and h/w = aspect_ratio + """ + h, w = round(math.sqrt(hw * aspect_ratio)), round(math.sqrt(hw / aspect_ratio)) + # find h and w such that h*w = hw and h/w = aspect_ratio + if h * w != hw: + w_candidate = hw / h + # check if w is an integer + if not w_candidate.is_integer(): + h_candidate = hw / w + # check if h is an integer + if not h_candidate.is_integer(): + return iterative_closest_divisors(hw, aspect_ratio) + else: + h = int(h_candidate) + else: + w = int(w_candidate) + return h, w + + +def self_attn_forward(params: HypertileParams, scale_depth=True) -> Callable: + + @wraps(params.forward) + def wrapper(*args, **kwargs): + if not params.enabled: + return params.forward(*args, **kwargs) + + latent_tile_size = max(128, params.tile_size) // 8 + x = args[0] + + # VAE + if x.ndim == 4: + b, c, h, w = x.shape + + nh = random_divisor(h, latent_tile_size, params.swap_size) + nw = random_divisor(w, latent_tile_size, params.swap_size) + + if nh * nw > 1: + x = rearrange(x, "b c (nh h) (nw w) -> (b nh nw) c h w", nh=nh, nw=nw) # split into nh * nw tiles + + out = params.forward(x, *args[1:], **kwargs) + + if nh * nw > 1: + out = rearrange(out, "(b nh nw) c h w -> b c (nh h) (nw w)", nh=nh, nw=nw) + + # U-Net + else: + hw: int = x.size(1) + h, w = find_hw_candidates(hw, params.aspect_ratio) + assert h * w == hw, f"Invalid aspect ratio {params.aspect_ratio} for input of shape {x.shape}, hw={hw}, h={h}, w={w}" + + factor = 2 ** params.depth if scale_depth else 1 + nh = random_divisor(h, latent_tile_size * factor, params.swap_size) + nw = random_divisor(w, latent_tile_size * factor, params.swap_size) + + if nh * nw > 1: + x = rearrange(x, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw) + + out = params.forward(x, *args[1:], **kwargs) + + if nh * nw > 1: + out = rearrange(out, "(b nh nw) hw c -> b nh nw hw c", nh=nh, nw=nw) + out = rearrange(out, "b nh nw (h w) c -> b (nh h nw w) c", h=h // nh, w=w // nw) + + return out + + return wrapper + + +def hypertile_hook_model(model: nn.Module, width, height, *, enable=False, tile_size_max=128, swap_size=1, max_depth=3, is_sdxl=False): + hypertile_layers = getattr(model, "__webui_hypertile_layers", None) + if hypertile_layers is None: + if not enable: + return + + hypertile_layers = {} + layers = DEPTH_LAYERS_XL if is_sdxl else DEPTH_LAYERS + + for depth in range(4): + for layer_name, module in model.named_modules(): + if any(layer_name.endswith(try_name) for try_name in layers[depth]): + params = HypertileParams() + module.__webui_hypertile_params = params + params.forward = module.forward + params.depth = depth + params.layer_name = layer_name + module.forward = self_attn_forward(params) + + hypertile_layers[layer_name] = 1 + + model.__webui_hypertile_layers = hypertile_layers + + aspect_ratio = width / height + tile_size = min(largest_tile_size_available(width, height), tile_size_max) + + for layer_name, module in model.named_modules(): + if layer_name in hypertile_layers: + params = module.__webui_hypertile_params + + params.tile_size = tile_size + params.swap_size = swap_size + params.aspect_ratio = aspect_ratio + params.enabled = enable and params.depth <= max_depth diff --git a/extensions-builtin/hypertile/scripts/hypertile_script.py b/extensions-builtin/hypertile/scripts/hypertile_script.py new file mode 100644 index 000000000..395d584b6 --- /dev/null +++ b/extensions-builtin/hypertile/scripts/hypertile_script.py @@ -0,0 +1,109 @@ +import hypertile +from modules import scripts, script_callbacks, shared +from scripts.hypertile_xyz import add_axis_options + + +class ScriptHypertile(scripts.Script): + name = "Hypertile" + + def title(self): + return self.name + + def show(self, is_img2img): + return scripts.AlwaysVisible + + def process(self, p, *args): + hypertile.set_hypertile_seed(p.all_seeds[0]) + + configure_hypertile(p.width, p.height, enable_unet=shared.opts.hypertile_enable_unet) + + self.add_infotext(p) + + def before_hr(self, p, *args): + + enable = shared.opts.hypertile_enable_unet_secondpass or shared.opts.hypertile_enable_unet + + # exclusive hypertile seed for the second pass + if enable: + hypertile.set_hypertile_seed(p.all_seeds[0]) + + configure_hypertile(p.hr_upscale_to_x, p.hr_upscale_to_y, enable_unet=enable) + + if enable and not shared.opts.hypertile_enable_unet: + p.extra_generation_params["Hypertile U-Net second pass"] = True + + self.add_infotext(p, add_unet_params=True) + + def add_infotext(self, p, add_unet_params=False): + def option(name): + value = getattr(shared.opts, name) + default_value = shared.opts.get_default(name) + return None if value == default_value else value + + if shared.opts.hypertile_enable_unet: + p.extra_generation_params["Hypertile U-Net"] = True + + if shared.opts.hypertile_enable_unet or add_unet_params: + p.extra_generation_params["Hypertile U-Net max depth"] = option('hypertile_max_depth_unet') + p.extra_generation_params["Hypertile U-Net max tile size"] = option('hypertile_max_tile_unet') + p.extra_generation_params["Hypertile U-Net swap size"] = option('hypertile_swap_size_unet') + + if shared.opts.hypertile_enable_vae: + p.extra_generation_params["Hypertile VAE"] = True + p.extra_generation_params["Hypertile VAE max depth"] = option('hypertile_max_depth_vae') + p.extra_generation_params["Hypertile VAE max tile size"] = option('hypertile_max_tile_vae') + p.extra_generation_params["Hypertile VAE swap size"] = option('hypertile_swap_size_vae') + + +def configure_hypertile(width, height, enable_unet=True): + hypertile.hypertile_hook_model( + shared.sd_model.first_stage_model, + width, + height, + swap_size=shared.opts.hypertile_swap_size_vae, + max_depth=shared.opts.hypertile_max_depth_vae, + tile_size_max=shared.opts.hypertile_max_tile_vae, + enable=shared.opts.hypertile_enable_vae, + ) + + hypertile.hypertile_hook_model( + shared.sd_model.model, + width, + height, + swap_size=shared.opts.hypertile_swap_size_unet, + max_depth=shared.opts.hypertile_max_depth_unet, + tile_size_max=shared.opts.hypertile_max_tile_unet, + enable=enable_unet, + is_sdxl=shared.sd_model.is_sdxl + ) + + +def on_ui_settings(): + import gradio as gr + + options = { + "hypertile_explanation": shared.OptionHTML(""" + Hypertile optimizes the self-attention layer within U-Net and VAE models, + resulting in a reduction in computation time ranging from 1 to 4 times. The larger the generated image is, the greater the + benefit. + """), + + "hypertile_enable_unet": shared.OptionInfo(False, "Enable Hypertile U-Net", infotext="Hypertile U-Net").info("enables hypertile for all modes, including hires fix second pass; noticeable change in details of the generated picture"), + "hypertile_enable_unet_secondpass": shared.OptionInfo(False, "Enable Hypertile U-Net for hires fix second pass", infotext="Hypertile U-Net second pass").info("enables hypertile just for hires fix second pass - regardless of whether the above setting is enabled"), + "hypertile_max_depth_unet": shared.OptionInfo(3, "Hypertile U-Net max depth", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}, infotext="Hypertile U-Net max depth").info("larger = more neural network layers affected; minor effect on performance"), + "hypertile_max_tile_unet": shared.OptionInfo(256, "Hypertile U-Net max tile size", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}, infotext="Hypertile U-Net max tile size").info("larger = worse performance"), + "hypertile_swap_size_unet": shared.OptionInfo(3, "Hypertile U-Net swap size", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}, infotext="Hypertile U-Net swap size"), + + "hypertile_enable_vae": shared.OptionInfo(False, "Enable Hypertile VAE", infotext="Hypertile VAE").info("minimal change in the generated picture"), + "hypertile_max_depth_vae": shared.OptionInfo(3, "Hypertile VAE max depth", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}, infotext="Hypertile VAE max depth"), + "hypertile_max_tile_vae": shared.OptionInfo(128, "Hypertile VAE max tile size", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}, infotext="Hypertile VAE max tile size"), + "hypertile_swap_size_vae": shared.OptionInfo(3, "Hypertile VAE swap size ", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}, infotext="Hypertile VAE swap size"), + } + + for name, opt in options.items(): + opt.section = ('hypertile', "Hypertile") + shared.opts.add_option(name, opt) + + +script_callbacks.on_ui_settings(on_ui_settings) +script_callbacks.on_before_ui(add_axis_options) diff --git a/extensions-builtin/hypertile/scripts/hypertile_xyz.py b/extensions-builtin/hypertile/scripts/hypertile_xyz.py new file mode 100644 index 000000000..9e96ae3c5 --- /dev/null +++ b/extensions-builtin/hypertile/scripts/hypertile_xyz.py @@ -0,0 +1,51 @@ +from modules import scripts +from modules.shared import opts + +xyz_grid = [x for x in scripts.scripts_data if x.script_class.__module__ == "xyz_grid.py"][0].module + +def int_applier(value_name:str, min_range:int = -1, max_range:int = -1): + """ + Returns a function that applies the given value to the given value_name in opts.data. + """ + def validate(value_name:str, value:str): + value = int(value) + # validate value + if not min_range == -1: + assert value >= min_range, f"Value {value} for {value_name} must be greater than or equal to {min_range}" + if not max_range == -1: + assert value <= max_range, f"Value {value} for {value_name} must be less than or equal to {max_range}" + def apply_int(p, x, xs): + validate(value_name, x) + opts.data[value_name] = int(x) + return apply_int + +def bool_applier(value_name:str): + """ + Returns a function that applies the given value to the given value_name in opts.data. + """ + def validate(value_name:str, value:str): + assert value.lower() in ["true", "false"], f"Value {value} for {value_name} must be either true or false" + def apply_bool(p, x, xs): + validate(value_name, x) + value_boolean = x.lower() == "true" + opts.data[value_name] = value_boolean + return apply_bool + +def add_axis_options(): + extra_axis_options = [ + xyz_grid.AxisOption("[Hypertile] Unet First pass Enabled", str, bool_applier("hypertile_enable_unet"), choices=xyz_grid.boolean_choice(reverse=True)), + xyz_grid.AxisOption("[Hypertile] Unet Second pass Enabled", str, bool_applier("hypertile_enable_unet_secondpass"), choices=xyz_grid.boolean_choice(reverse=True)), + xyz_grid.AxisOption("[Hypertile] Unet Max Depth", int, int_applier("hypertile_max_depth_unet", 0, 3), choices=lambda: [str(x) for x in range(4)]), + xyz_grid.AxisOption("[Hypertile] Unet Max Tile Size", int, int_applier("hypertile_max_tile_unet", 0, 512)), + xyz_grid.AxisOption("[Hypertile] Unet Swap Size", int, int_applier("hypertile_swap_size_unet", 0, 64)), + xyz_grid.AxisOption("[Hypertile] VAE Enabled", str, bool_applier("hypertile_enable_vae"), choices=xyz_grid.boolean_choice(reverse=True)), + xyz_grid.AxisOption("[Hypertile] VAE Max Depth", int, int_applier("hypertile_max_depth_vae", 0, 3), choices=lambda: [str(x) for x in range(4)]), + xyz_grid.AxisOption("[Hypertile] VAE Max Tile Size", int, int_applier("hypertile_max_tile_vae", 0, 512)), + xyz_grid.AxisOption("[Hypertile] VAE Swap Size", int, int_applier("hypertile_swap_size_vae", 0, 64)), + ] + set_a = {opt.label for opt in xyz_grid.axis_options} + set_b = {opt.label for opt in extra_axis_options} + if set_a.intersection(set_b): + return + + xyz_grid.axis_options.extend(extra_axis_options) diff --git a/extensions-builtin/mobile/javascript/mobile.js b/extensions-builtin/mobile/javascript/mobile.js index 652f07ac7..bff1acedf 100644 --- a/extensions-builtin/mobile/javascript/mobile.js +++ b/extensions-builtin/mobile/javascript/mobile.js @@ -12,6 +12,8 @@ function isMobile() { } function reportWindowSize() { + if (gradioApp().querySelector('.toprow-compact-tools')) return; // not applicable for compact prompt layout + var currentlyMobile = isMobile(); if (currentlyMobile == isSetupForMobile) return; isSetupForMobile = currentlyMobile; diff --git a/extensions-builtin/soft-inpainting/scripts/soft_inpainting.py b/extensions-builtin/soft-inpainting/scripts/soft_inpainting.py new file mode 100644 index 000000000..d90243442 --- /dev/null +++ b/extensions-builtin/soft-inpainting/scripts/soft_inpainting.py @@ -0,0 +1,747 @@ +import numpy as np +import gradio as gr +import math +from modules.ui_components import InputAccordion +import modules.scripts as scripts + + +class SoftInpaintingSettings: + def __init__(self, + mask_blend_power, + mask_blend_scale, + inpaint_detail_preservation, + composite_mask_influence, + composite_difference_threshold, + composite_difference_contrast): + self.mask_blend_power = mask_blend_power + self.mask_blend_scale = mask_blend_scale + self.inpaint_detail_preservation = inpaint_detail_preservation + self.composite_mask_influence = composite_mask_influence + self.composite_difference_threshold = composite_difference_threshold + self.composite_difference_contrast = composite_difference_contrast + + def add_generation_params(self, dest): + dest[enabled_gen_param_label] = True + dest[gen_param_labels.mask_blend_power] = self.mask_blend_power + dest[gen_param_labels.mask_blend_scale] = self.mask_blend_scale + dest[gen_param_labels.inpaint_detail_preservation] = self.inpaint_detail_preservation + dest[gen_param_labels.composite_mask_influence] = self.composite_mask_influence + dest[gen_param_labels.composite_difference_threshold] = self.composite_difference_threshold + dest[gen_param_labels.composite_difference_contrast] = self.composite_difference_contrast + + +# ------------------- Methods ------------------- + +def processing_uses_inpainting(p): + # TODO: Figure out a better way to determine if inpainting is being used by p + if getattr(p, "image_mask", None) is not None: + return True + + if getattr(p, "mask", None) is not None: + return True + + if getattr(p, "nmask", None) is not None: + return True + + return False + + +def latent_blend(settings, a, b, t): + """ + Interpolates two latent image representations according to the parameter t, + where the interpolated vectors' magnitudes are also interpolated separately. + The "detail_preservation" factor biases the magnitude interpolation towards + the larger of the two magnitudes. + """ + import torch + + # NOTE: We use inplace operations wherever possible. + + # [4][w][h] to [1][4][w][h] + t2 = t.unsqueeze(0) + # [4][w][h] to [1][1][w][h] - the [4] seem redundant. + t3 = t[0].unsqueeze(0).unsqueeze(0) + + one_minus_t2 = 1 - t2 + one_minus_t3 = 1 - t3 + + # Linearly interpolate the image vectors. + a_scaled = a * one_minus_t2 + b_scaled = b * t2 + image_interp = a_scaled + image_interp.add_(b_scaled) + result_type = image_interp.dtype + del a_scaled, b_scaled, t2, one_minus_t2 + + # Calculate the magnitude of the interpolated vectors. (We will remove this magnitude.) + # 64-bit operations are used here to allow large exponents. + current_magnitude = torch.norm(image_interp, p=2, dim=1, keepdim=True).to(torch.float64).add_(0.00001) + + # Interpolate the powered magnitudes, then un-power them (bring them back to a power of 1). + a_magnitude = torch.norm(a, p=2, dim=1, keepdim=True).to(torch.float64).pow_( + settings.inpaint_detail_preservation) * one_minus_t3 + b_magnitude = torch.norm(b, p=2, dim=1, keepdim=True).to(torch.float64).pow_( + settings.inpaint_detail_preservation) * t3 + desired_magnitude = a_magnitude + desired_magnitude.add_(b_magnitude).pow_(1 / settings.inpaint_detail_preservation) + del a_magnitude, b_magnitude, t3, one_minus_t3 + + # Change the linearly interpolated image vectors' magnitudes to the value we want. + # This is the last 64-bit operation. + image_interp_scaling_factor = desired_magnitude + image_interp_scaling_factor.div_(current_magnitude) + image_interp_scaling_factor = image_interp_scaling_factor.to(result_type) + image_interp_scaled = image_interp + image_interp_scaled.mul_(image_interp_scaling_factor) + del current_magnitude + del desired_magnitude + del image_interp + del image_interp_scaling_factor + del result_type + + return image_interp_scaled + + +def get_modified_nmask(settings, nmask, sigma): + """ + Converts a negative mask representing the transparency of the original latent vectors being overlayed + to a mask that is scaled according to the denoising strength for this step. + + Where: + 0 = fully opaque, infinite density, fully masked + 1 = fully transparent, zero density, fully unmasked + + We bring this transparency to a power, as this allows one to simulate N number of blending operations + where N can be any positive real value. Using this one can control the balance of influence between + the denoiser and the original latents according to the sigma value. + + NOTE: "mask" is not used + """ + import torch + return torch.pow(nmask, (sigma ** settings.mask_blend_power) * settings.mask_blend_scale) + + +def apply_adaptive_masks( + settings: SoftInpaintingSettings, + nmask, + latent_orig, + latent_processed, + overlay_images, + width, height, + paste_to): + import torch + import modules.processing as proc + import modules.images as images + from PIL import Image, ImageOps, ImageFilter + + # TODO: Bias the blending according to the latent mask, add adjustable parameter for bias control. + latent_mask = nmask[0].float() + # convert the original mask into a form we use to scale distances for thresholding + mask_scalar = 1 - (torch.clamp(latent_mask, min=0, max=1) ** (settings.mask_blend_scale / 2)) + mask_scalar = (0.5 * (1 - settings.composite_mask_influence) + + mask_scalar * settings.composite_mask_influence) + mask_scalar = mask_scalar / (1.00001 - mask_scalar) + mask_scalar = mask_scalar.cpu().numpy() + + latent_distance = torch.norm(latent_processed - latent_orig, p=2, dim=1) + + kernel, kernel_center = get_gaussian_kernel(stddev_radius=1.5, max_radius=2) + + masks_for_overlay = [] + + for i, (distance_map, overlay_image) in enumerate(zip(latent_distance, overlay_images)): + converted_mask = distance_map.float().cpu().numpy() + converted_mask = weighted_histogram_filter(converted_mask, kernel, kernel_center, + percentile_min=0.9, percentile_max=1, min_width=1) + converted_mask = weighted_histogram_filter(converted_mask, kernel, kernel_center, + percentile_min=0.25, percentile_max=0.75, min_width=1) + + # The distance at which opacity of original decreases to 50% + half_weighted_distance = settings.composite_difference_threshold * mask_scalar + converted_mask = converted_mask / half_weighted_distance + + converted_mask = 1 / (1 + converted_mask ** settings.composite_difference_contrast) + converted_mask = smootherstep(converted_mask) + converted_mask = 1 - converted_mask + converted_mask = 255. * converted_mask + converted_mask = converted_mask.astype(np.uint8) + converted_mask = Image.fromarray(converted_mask) + converted_mask = images.resize_image(2, converted_mask, width, height) + converted_mask = proc.create_binary_mask(converted_mask, round=False) + + # Remove aliasing artifacts using a gaussian blur. + converted_mask = converted_mask.filter(ImageFilter.GaussianBlur(radius=4)) + + # Expand the mask to fit the whole image if needed. + if paste_to is not None: + converted_mask = proc.uncrop(converted_mask, + (overlay_image.width, overlay_image.height), + paste_to) + + masks_for_overlay.append(converted_mask) + + image_masked = Image.new('RGBa', (overlay_image.width, overlay_image.height)) + image_masked.paste(overlay_image.convert("RGBA").convert("RGBa"), + mask=ImageOps.invert(converted_mask.convert('L'))) + + overlay_images[i] = image_masked.convert('RGBA') + + return masks_for_overlay + + +def apply_masks( + settings, + nmask, + overlay_images, + width, height, + paste_to): + import torch + import modules.processing as proc + import modules.images as images + from PIL import Image, ImageOps, ImageFilter + + converted_mask = nmask[0].float() + converted_mask = torch.clamp(converted_mask, min=0, max=1).pow_(settings.mask_blend_scale / 2) + converted_mask = 255. * converted_mask + converted_mask = converted_mask.cpu().numpy().astype(np.uint8) + converted_mask = Image.fromarray(converted_mask) + converted_mask = images.resize_image(2, converted_mask, width, height) + converted_mask = proc.create_binary_mask(converted_mask, round=False) + + # Remove aliasing artifacts using a gaussian blur. + converted_mask = converted_mask.filter(ImageFilter.GaussianBlur(radius=4)) + + # Expand the mask to fit the whole image if needed. + if paste_to is not None: + converted_mask = proc.uncrop(converted_mask, + (width, height), + paste_to) + + masks_for_overlay = [] + + for i, overlay_image in enumerate(overlay_images): + masks_for_overlay[i] = converted_mask + + image_masked = Image.new('RGBa', (overlay_image.width, overlay_image.height)) + image_masked.paste(overlay_image.convert("RGBA").convert("RGBa"), + mask=ImageOps.invert(converted_mask.convert('L'))) + + overlay_images[i] = image_masked.convert('RGBA') + + return masks_for_overlay + + +def weighted_histogram_filter(img, kernel, kernel_center, percentile_min=0.0, percentile_max=1.0, min_width=1.0): + """ + Generalization convolution filter capable of applying + weighted mean, median, maximum, and minimum filters + parametrically using an arbitrary kernel. + + Args: + img (nparray): + The image, a 2-D array of floats, to which the filter is being applied. + kernel (nparray): + The kernel, a 2-D array of floats. + kernel_center (nparray): + The kernel center coordinate, a 1-D array with two elements. + percentile_min (float): + The lower bound of the histogram window used by the filter, + from 0 to 1. + percentile_max (float): + The upper bound of the histogram window used by the filter, + from 0 to 1. + min_width (float): + The minimum size of the histogram window bounds, in weight units. + Must be greater than 0. + + Returns: + (nparray): A filtered copy of the input image "img", a 2-D array of floats. + """ + + # Converts an index tuple into a vector. + def vec(x): + return np.array(x) + + kernel_min = -kernel_center + kernel_max = vec(kernel.shape) - kernel_center + + def weighted_histogram_filter_single(idx): + idx = vec(idx) + min_index = np.maximum(0, idx + kernel_min) + max_index = np.minimum(vec(img.shape), idx + kernel_max) + window_shape = max_index - min_index + + class WeightedElement: + """ + An element of the histogram, its weight + and bounds. + """ + + def __init__(self, value, weight): + self.value: float = value + self.weight: float = weight + self.window_min: float = 0.0 + self.window_max: float = 1.0 + + # Collect the values in the image as WeightedElements, + # weighted by their corresponding kernel values. + values = [] + for window_tup in np.ndindex(tuple(window_shape)): + window_index = vec(window_tup) + image_index = window_index + min_index + centered_kernel_index = image_index - idx + kernel_index = centered_kernel_index + kernel_center + element = WeightedElement(img[tuple(image_index)], kernel[tuple(kernel_index)]) + values.append(element) + + def sort_key(x: WeightedElement): + return x.value + + values.sort(key=sort_key) + + # Calculate the height of the stack (sum) + # and each sample's range they occupy in the stack + sum = 0 + for i in range(len(values)): + values[i].window_min = sum + sum += values[i].weight + values[i].window_max = sum + + # Calculate what range of this stack ("window") + # we want to get the weighted average across. + window_min = sum * percentile_min + window_max = sum * percentile_max + window_width = window_max - window_min + + # Ensure the window is within the stack and at least a certain size. + if window_width < min_width: + window_center = (window_min + window_max) / 2 + window_min = window_center - min_width / 2 + window_max = window_center + min_width / 2 + + if window_max > sum: + window_max = sum + window_min = sum - min_width + + if window_min < 0: + window_min = 0 + window_max = min_width + + value = 0 + value_weight = 0 + + # Get the weighted average of all the samples + # that overlap with the window, weighted + # by the size of their overlap. + for i in range(len(values)): + if window_min >= values[i].window_max: + continue + if window_max <= values[i].window_min: + break + + s = max(window_min, values[i].window_min) + e = min(window_max, values[i].window_max) + w = e - s + + value += values[i].value * w + value_weight += w + + return value / value_weight if value_weight != 0 else 0 + + img_out = img.copy() + + # Apply the kernel operation over each pixel. + for index in np.ndindex(img.shape): + img_out[index] = weighted_histogram_filter_single(index) + + return img_out + + +def smoothstep(x): + """ + The smoothstep function, input should be clamped to 0-1 range. + Turns a diagonal line (f(x) = x) into a sigmoid-like curve. + """ + return x * x * (3 - 2 * x) + + +def smootherstep(x): + """ + The smootherstep function, input should be clamped to 0-1 range. + Turns a diagonal line (f(x) = x) into a sigmoid-like curve. + """ + return x * x * x * (x * (6 * x - 15) + 10) + + +def get_gaussian_kernel(stddev_radius=1.0, max_radius=2): + """ + Creates a Gaussian kernel with thresholded edges. + + Args: + stddev_radius (float): + Standard deviation of the gaussian kernel, in pixels. + max_radius (int): + The size of the filter kernel. The number of pixels is (max_radius*2+1) ** 2. + The kernel is thresholded so that any values one pixel beyond this radius + is weighted at 0. + + Returns: + (nparray, nparray): A kernel array (shape: (N, N)), its center coordinate (shape: (2)) + """ + + # Evaluates a 0-1 normalized gaussian function for a given square distance from the mean. + def gaussian(sqr_mag): + return math.exp(-sqr_mag / (stddev_radius * stddev_radius)) + + # Helper function for converting a tuple to an array. + def vec(x): + return np.array(x) + + """ + Since a gaussian is unbounded, we need to limit ourselves + to a finite range. + We taper the ends off at the end of that range so they equal zero + while preserving the maximum value of 1 at the mean. + """ + zero_radius = max_radius + 1.0 + gauss_zero = gaussian(zero_radius * zero_radius) + gauss_kernel_scale = 1 / (1 - gauss_zero) + + def gaussian_kernel_func(coordinate): + x = coordinate[0] ** 2.0 + coordinate[1] ** 2.0 + x = gaussian(x) + x -= gauss_zero + x *= gauss_kernel_scale + x = max(0.0, x) + return x + + size = max_radius * 2 + 1 + kernel_center = max_radius + kernel = np.zeros((size, size)) + + for index in np.ndindex(kernel.shape): + kernel[index] = gaussian_kernel_func(vec(index) - kernel_center) + + return kernel, kernel_center + + +# ------------------- Constants ------------------- + + +default = SoftInpaintingSettings(1, 0.5, 4, 0, 0.5, 2) + +enabled_ui_label = "Soft inpainting" +enabled_gen_param_label = "Soft inpainting enabled" +enabled_el_id = "soft_inpainting_enabled" + +ui_labels = SoftInpaintingSettings( + "Schedule bias", + "Preservation strength", + "Transition contrast boost", + "Mask influence", + "Difference threshold", + "Difference contrast") + +ui_info = SoftInpaintingSettings( + "Shifts when preservation of original content occurs during denoising.", + "How strongly partially masked content should be preserved.", + "Amplifies the contrast that may be lost in partially masked regions.", + "How strongly the original mask should bias the difference threshold.", + "How much an image region can change before the original pixels are not blended in anymore.", + "How sharp the transition should be between blended and not blended.") + +gen_param_labels = SoftInpaintingSettings( + "Soft inpainting schedule bias", + "Soft inpainting preservation strength", + "Soft inpainting transition contrast boost", + "Soft inpainting mask influence", + "Soft inpainting difference threshold", + "Soft inpainting difference contrast") + +el_ids = SoftInpaintingSettings( + "mask_blend_power", + "mask_blend_scale", + "inpaint_detail_preservation", + "composite_mask_influence", + "composite_difference_threshold", + "composite_difference_contrast") + + +# ------------------- Script ------------------- + + +class Script(scripts.Script): + def __init__(self): + self.section = "inpaint" + self.masks_for_overlay = None + self.overlay_images = None + + def title(self): + return "Soft Inpainting" + + def show(self, is_img2img): + return scripts.AlwaysVisible if is_img2img else False + + def ui(self, is_img2img): + if not is_img2img: + return + + with InputAccordion(False, label=enabled_ui_label, elem_id=enabled_el_id) as soft_inpainting_enabled: + with gr.Group(): + gr.Markdown( + """ + Soft inpainting allows you to **seamlessly blend original content with inpainted content** according to the mask opacity. + **High _Mask blur_** values are recommended! + """) + + power = \ + gr.Slider(label=ui_labels.mask_blend_power, + info=ui_info.mask_blend_power, + minimum=0, + maximum=8, + step=0.1, + value=default.mask_blend_power, + elem_id=el_ids.mask_blend_power) + scale = \ + gr.Slider(label=ui_labels.mask_blend_scale, + info=ui_info.mask_blend_scale, + minimum=0, + maximum=8, + step=0.05, + value=default.mask_blend_scale, + elem_id=el_ids.mask_blend_scale) + detail = \ + gr.Slider(label=ui_labels.inpaint_detail_preservation, + info=ui_info.inpaint_detail_preservation, + minimum=1, + maximum=32, + step=0.5, + value=default.inpaint_detail_preservation, + elem_id=el_ids.inpaint_detail_preservation) + + gr.Markdown( + """ + ### Pixel Composite Settings + """) + + mask_inf = \ + gr.Slider(label=ui_labels.composite_mask_influence, + info=ui_info.composite_mask_influence, + minimum=0, + maximum=1, + step=0.05, + value=default.composite_mask_influence, + elem_id=el_ids.composite_mask_influence) + + dif_thresh = \ + gr.Slider(label=ui_labels.composite_difference_threshold, + info=ui_info.composite_difference_threshold, + minimum=0, + maximum=8, + step=0.25, + value=default.composite_difference_threshold, + elem_id=el_ids.composite_difference_threshold) + + dif_contr = \ + gr.Slider(label=ui_labels.composite_difference_contrast, + info=ui_info.composite_difference_contrast, + minimum=0, + maximum=8, + step=0.25, + value=default.composite_difference_contrast, + elem_id=el_ids.composite_difference_contrast) + + with gr.Accordion("Help", open=False): + gr.Markdown( + f""" + ### {ui_labels.mask_blend_power} + + The blending strength of original content is scaled proportionally with the decreasing noise level values at each step (sigmas). + This ensures that the influence of the denoiser and original content preservation is roughly balanced at each step. + This balance can be shifted using this parameter, controlling whether earlier or later steps have stronger preservation. + + - **Below 1**: Stronger preservation near the end (with low sigma) + - **1**: Balanced (proportional to sigma) + - **Above 1**: Stronger preservation in the beginning (with high sigma) + """) + gr.Markdown( + f""" + ### {ui_labels.mask_blend_scale} + + Skews whether partially masked image regions should be more likely to preserve the original content or favor inpainted content. + This may need to be adjusted depending on the {ui_labels.mask_blend_power}, CFG Scale, prompt and Denoising strength. + + - **Low values**: Favors generated content. + - **High values**: Favors original content. + """) + gr.Markdown( + f""" + ### {ui_labels.inpaint_detail_preservation} + + This parameter controls how the original latent vectors and denoised latent vectors are interpolated. + With higher values, the magnitude of the resulting blended vector will be closer to the maximum of the two interpolated vectors. + This can prevent the loss of contrast that occurs with linear interpolation. + + - **Low values**: Softer blending, details may fade. + - **High values**: Stronger contrast, may over-saturate colors. + """) + + gr.Markdown( + """ + ## Pixel Composite Settings + + Masks are generated based on how much a part of the image changed after denoising. + These masks are used to blend the original and final images together. + If the difference is low, the original pixels are used instead of the pixels returned by the inpainting process. + """) + + gr.Markdown( + f""" + ### {ui_labels.composite_mask_influence} + + This parameter controls how much the mask should bias this sensitivity to difference. + + - **0**: Ignore the mask, only consider differences in image content. + - **1**: Follow the mask closely despite image content changes. + """) + + gr.Markdown( + f""" + ### {ui_labels.composite_difference_threshold} + + This value represents the difference at which the original pixels will have less than 50% opacity. + + - **Low values**: Two images patches must be almost the same in order to retain original pixels. + - **High values**: Two images patches can be very different and still retain original pixels. + """) + + gr.Markdown( + f""" + ### {ui_labels.composite_difference_contrast} + + This value represents the contrast between the opacity of the original and inpainted content. + + - **Low values**: The blend will be more gradual and have longer transitions, but may cause ghosting. + - **High values**: Ghosting will be less common, but transitions may be very sudden. + """) + + self.infotext_fields = [(soft_inpainting_enabled, enabled_gen_param_label), + (power, gen_param_labels.mask_blend_power), + (scale, gen_param_labels.mask_blend_scale), + (detail, gen_param_labels.inpaint_detail_preservation), + (mask_inf, gen_param_labels.composite_mask_influence), + (dif_thresh, gen_param_labels.composite_difference_threshold), + (dif_contr, gen_param_labels.composite_difference_contrast)] + + self.paste_field_names = [] + for _, field_name in self.infotext_fields: + self.paste_field_names.append(field_name) + + return [soft_inpainting_enabled, + power, + scale, + detail, + mask_inf, + dif_thresh, + dif_contr] + + def process(self, p, enabled, power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr): + if not enabled: + return + + if not processing_uses_inpainting(p): + return + + # Shut off the rounding it normally does. + p.mask_round = False + + settings = SoftInpaintingSettings(power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr) + + # p.extra_generation_params["Mask rounding"] = False + settings.add_generation_params(p.extra_generation_params) + + def on_mask_blend(self, p, mba: scripts.MaskBlendArgs, enabled, power, scale, detail_preservation, mask_inf, + dif_thresh, dif_contr): + if not enabled: + return + + if not processing_uses_inpainting(p): + return + + if mba.is_final_blend: + mba.blended_latent = mba.current_latent + return + + settings = SoftInpaintingSettings(power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr) + + # todo: Why is sigma 2D? Both values are the same. + mba.blended_latent = latent_blend(settings, + mba.init_latent, + mba.current_latent, + get_modified_nmask(settings, mba.nmask, mba.sigma[0])) + + def post_sample(self, p, ps: scripts.PostSampleArgs, enabled, power, scale, detail_preservation, mask_inf, + dif_thresh, dif_contr): + if not enabled: + return + + if not processing_uses_inpainting(p): + return + + nmask = getattr(p, "nmask", None) + if nmask is None: + return + + from modules import images + from modules.shared import opts + + settings = SoftInpaintingSettings(power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr) + + # since the original code puts holes in the existing overlay images, + # we have to rebuild them. + self.overlay_images = [] + for img in p.init_images: + + image = images.flatten(img, opts.img2img_background_color) + + if p.paste_to is None and p.resize_mode != 3: + image = images.resize_image(p.resize_mode, image, p.width, p.height) + + self.overlay_images.append(image.convert('RGBA')) + + if len(p.init_images) == 1: + self.overlay_images = self.overlay_images * p.batch_size + + if getattr(ps.samples, 'already_decoded', False): + self.masks_for_overlay = apply_masks(settings=settings, + nmask=nmask, + overlay_images=self.overlay_images, + width=p.width, + height=p.height, + paste_to=p.paste_to) + else: + self.masks_for_overlay = apply_adaptive_masks(settings=settings, + nmask=nmask, + latent_orig=p.init_latent, + latent_processed=ps.samples, + overlay_images=self.overlay_images, + width=p.width, + height=p.height, + paste_to=p.paste_to) + + def postprocess_maskoverlay(self, p, ppmo: scripts.PostProcessMaskOverlayArgs, enabled, power, scale, + detail_preservation, mask_inf, dif_thresh, dif_contr): + if not enabled: + return + + if not processing_uses_inpainting(p): + return + + if self.masks_for_overlay is None: + return + + if self.overlay_images is None: + return + + ppmo.mask_for_overlay = self.masks_for_overlay[ppmo.index] + ppmo.overlay_image = self.overlay_images[ppmo.index] diff --git a/javascript/edit-attention.js b/javascript/edit-attention.js index 044641006..688c2f112 100644 --- a/javascript/edit-attention.js +++ b/javascript/edit-attention.js @@ -28,7 +28,7 @@ function keyupEditAttention(event) { if (afterParen == -1) return false; let afterOpeningParen = after.indexOf(OPEN); - if (afterOpeningParen != -1 && afterOpeningParen < beforeParen) return false; + if (afterOpeningParen != -1 && afterOpeningParen < afterParen) return false; // Set the selection to the text between the parenthesis const parenContent = text.substring(beforeParen + 1, selectionStart + afterParen); diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index ac26718f6..f1ad19a66 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -26,8 +26,9 @@ function setupExtraNetworksForTab(tabname) { var refresh = gradioApp().getElementById(tabname + '_extra_refresh'); var showDirsDiv = gradioApp().getElementById(tabname + '_extra_show_dirs'); var showDirs = gradioApp().querySelector('#' + tabname + '_extra_show_dirs input'); + var promptContainer = gradioApp().querySelector('.prompt-container-compact#' + tabname + '_prompt_container'); + var negativePrompt = gradioApp().querySelector('#' + tabname + '_neg_prompt'); - sort.dataset.sortkey = 'sortDefault'; tabs.appendChild(searchDiv); tabs.appendChild(sort); tabs.appendChild(sortOrder); @@ -49,20 +50,23 @@ function setupExtraNetworksForTab(tabname) { elem.style.display = visible ? "" : "none"; }); + + applySort(); }; var applySort = function() { + var cards = gradioApp().querySelectorAll('#' + tabname + '_extra_tabs div.card'); + var reverse = sortOrder.classList.contains("sortReverse"); - var sortKey = sort.querySelector("input").value.toLowerCase().replace("sort", "").replaceAll(" ", "_").replace(/_+$/, "").trim(); - sortKey = sortKey ? "sort" + sortKey.charAt(0).toUpperCase() + sortKey.slice(1) : ""; - var sortKeyStore = sortKey ? sortKey + (reverse ? "Reverse" : "") : ""; - if (!sortKey || sortKeyStore == sort.dataset.sortkey) { + var sortKey = sort.querySelector("input").value.toLowerCase().replace("sort", "").replaceAll(" ", "_").replace(/_+$/, "").trim() || "name"; + sortKey = "sort" + sortKey.charAt(0).toUpperCase() + sortKey.slice(1); + var sortKeyStore = sortKey + "-" + (reverse ? "Descending" : "Ascending") + "-" + cards.length; + + if (sortKeyStore == sort.dataset.sortkey) { return; } - sort.dataset.sortkey = sortKeyStore; - var cards = gradioApp().querySelectorAll('#' + tabname + '_extra_tabs div.card'); cards.forEach(function(card) { card.originalParentElement = card.parentElement; }); @@ -88,15 +92,13 @@ function setupExtraNetworksForTab(tabname) { }; search.addEventListener("input", applyFilter); - applyFilter(); - ["change", "blur", "click"].forEach(function(evt) { - sort.querySelector("input").addEventListener(evt, applySort); - }); sortOrder.addEventListener("click", function() { sortOrder.classList.toggle("sortReverse"); applySort(); }); + applyFilter(); + extraNetworksApplySort[tabname] = applySort; extraNetworksApplyFilter[tabname] = applyFilter; var showDirsUpdate = function() { @@ -109,11 +111,51 @@ function setupExtraNetworksForTab(tabname) { showDirsUpdate(); } +function extraNetworksMovePromptToTab(tabname, id, showPrompt, showNegativePrompt) { + if (!gradioApp().querySelector('.toprow-compact-tools')) return; // only applicable for compact prompt layout + + var promptContainer = gradioApp().getElementById(tabname + '_prompt_container'); + var prompt = gradioApp().getElementById(tabname + '_prompt_row'); + var negPrompt = gradioApp().getElementById(tabname + '_neg_prompt_row'); + var elem = id ? gradioApp().getElementById(id) : null; + + if (showNegativePrompt && elem) { + elem.insertBefore(negPrompt, elem.firstChild); + } else { + promptContainer.insertBefore(negPrompt, promptContainer.firstChild); + } + + if (showPrompt && elem) { + elem.insertBefore(prompt, elem.firstChild); + } else { + promptContainer.insertBefore(prompt, promptContainer.firstChild); + } + + if (elem) { + elem.classList.toggle('extra-page-prompts-active', showNegativePrompt || showPrompt); + } +} + + +function extraNetworksUrelatedTabSelected(tabname) { // called from python when user selects an unrelated tab (generate) + extraNetworksMovePromptToTab(tabname, '', false, false); +} + +function extraNetworksTabSelected(tabname, id, showPrompt, showNegativePrompt) { // called from python when user selects an extra networks tab + extraNetworksMovePromptToTab(tabname, id, showPrompt, showNegativePrompt); + +} + function applyExtraNetworkFilter(tabname) { setTimeout(extraNetworksApplyFilter[tabname], 1); } +function applyExtraNetworkSort(tabname) { + setTimeout(extraNetworksApplySort[tabname], 1); +} + var extraNetworksApplyFilter = {}; +var extraNetworksApplySort = {}; var activePromptTextarea = {}; function setupExtraNetworks() { @@ -143,8 +185,10 @@ onUiLoaded(setupExtraNetworks); var re_extranet = /<([^:^>]+:[^:]+):[\d.]+>(.*)/; var re_extranet_g = /<([^:^>]+:[^:]+):[\d.]+>/g; -function tryToRemoveExtraNetworkFromPrompt(textarea, text) { - var m = text.match(re_extranet); +var re_extranet_neg = /\(([^:^>]+:[\d.]+)\)/; +var re_extranet_g_neg = /\(([^:^>]+:[\d.]+)\)/g; +function tryToRemoveExtraNetworkFromPrompt(textarea, text, isNeg) { + var m = text.match(isNeg ? re_extranet_neg : re_extranet); var replaced = false; var newTextareaText; if (m) { @@ -152,8 +196,8 @@ function tryToRemoveExtraNetworkFromPrompt(textarea, text) { var extraTextAfterNet = m[2]; var partToSearch = m[1]; var foundAtPosition = -1; - newTextareaText = textarea.value.replaceAll(re_extranet_g, function(found, net, pos) { - m = found.match(re_extranet); + newTextareaText = textarea.value.replaceAll(isNeg ? re_extranet_g_neg : re_extranet_g, function(found, net, pos) { + m = found.match(isNeg ? re_extranet_neg : re_extranet); if (m[1] == partToSearch) { replaced = true; foundAtPosition = pos; @@ -163,7 +207,7 @@ function tryToRemoveExtraNetworkFromPrompt(textarea, text) { }); if (foundAtPosition >= 0) { - if (newTextareaText.substr(foundAtPosition, extraTextAfterNet.length) == extraTextAfterNet) { + if (extraTextAfterNet && newTextareaText.substr(foundAtPosition, extraTextAfterNet.length) == extraTextAfterNet) { newTextareaText = newTextareaText.substr(0, foundAtPosition) + newTextareaText.substr(foundAtPosition + extraTextAfterNet.length); } if (newTextareaText.substr(foundAtPosition - extraTextBeforeNet.length, extraTextBeforeNet.length) == extraTextBeforeNet) { @@ -188,14 +232,23 @@ function tryToRemoveExtraNetworkFromPrompt(textarea, text) { return false; } -function cardClicked(tabname, textToAdd, allowNegativePrompt) { - var textarea = allowNegativePrompt ? activePromptTextarea[tabname] : gradioApp().querySelector("#" + tabname + "_prompt > label > textarea"); +function updatePromptArea(text, textArea, isNeg) { - if (!tryToRemoveExtraNetworkFromPrompt(textarea, textToAdd)) { - textarea.value = textarea.value + opts.extra_networks_add_text_separator + textToAdd; + if (!tryToRemoveExtraNetworkFromPrompt(textArea, text, isNeg)) { + textArea.value = textArea.value + opts.extra_networks_add_text_separator + text; } - updateInput(textarea); + updateInput(textArea); +} + +function cardClicked(tabname, textToAdd, textToAddNegative, allowNegativePrompt) { + if (textToAddNegative.length > 0) { + updatePromptArea(textToAdd, gradioApp().querySelector("#" + tabname + "_prompt > label > textarea")); + updatePromptArea(textToAddNegative, gradioApp().querySelector("#" + tabname + "_neg_prompt > label > textarea"), true); + } else { + var textarea = allowNegativePrompt ? activePromptTextarea[tabname] : gradioApp().querySelector("#" + tabname + "_prompt > label > textarea"); + updatePromptArea(textToAdd, textarea); + } } function saveCardPreview(event, tabname, filename) { @@ -350,3 +403,9 @@ function extraNetworksRefreshSingleCard(page, tabname, name) { } }); } + +window.addEventListener("keydown", function(event) { + if (event.key == "Escape") { + closePopup(); + } +}); diff --git a/javascript/imageviewer.js b/javascript/imageviewer.js index e4dae91bc..625c5d148 100644 --- a/javascript/imageviewer.js +++ b/javascript/imageviewer.js @@ -34,7 +34,7 @@ function updateOnBackgroundChange() { if (modalImage && modalImage.offsetParent) { let currentButton = selected_gallery_button(); let preview = gradioApp().querySelectorAll('.livePreview > img'); - if (preview.length > 0) { + if (opts.js_live_preview_in_modal_lightbox && preview.length > 0) { // show preview image if available modalImage.src = preview[preview.length - 1].src; } else if (currentButton?.children?.length > 0 && modalImage.src != currentButton.children[0].src) { diff --git a/javascript/inputAccordion.js b/javascript/inputAccordion.js index f2839852e..7570309aa 100644 --- a/javascript/inputAccordion.js +++ b/javascript/inputAccordion.js @@ -1,37 +1,68 @@ -var observerAccordionOpen = new MutationObserver(function(mutations) { - mutations.forEach(function(mutationRecord) { - var elem = mutationRecord.target; - var open = elem.classList.contains('open'); - - var accordion = elem.parentNode; - accordion.classList.toggle('input-accordion-open', open); - - var checkbox = gradioApp().querySelector('#' + accordion.id + "-checkbox input"); - checkbox.checked = open; - updateInput(checkbox); - - var extra = gradioApp().querySelector('#' + accordion.id + "-extra"); - if (extra) { - extra.style.display = open ? "" : "none"; - } - }); -}); - function inputAccordionChecked(id, checked) { - var label = gradioApp().querySelector('#' + id + " .label-wrap"); - if (label.classList.contains('open') != checked) { - label.click(); + var accordion = gradioApp().getElementById(id); + accordion.visibleCheckbox.checked = checked; + accordion.onVisibleCheckboxChange(); +} + +function setupAccordion(accordion) { + var labelWrap = accordion.querySelector('.label-wrap'); + var gradioCheckbox = gradioApp().querySelector('#' + accordion.id + "-checkbox input"); + var extra = gradioApp().querySelector('#' + accordion.id + "-extra"); + var span = labelWrap.querySelector('span'); + var linked = true; + + var isOpen = function() { + return labelWrap.classList.contains('open'); + }; + + var observerAccordionOpen = new MutationObserver(function(mutations) { + mutations.forEach(function(mutationRecord) { + accordion.classList.toggle('input-accordion-open', isOpen()); + + if (linked) { + accordion.visibleCheckbox.checked = isOpen(); + accordion.onVisibleCheckboxChange(); + } + }); + }); + observerAccordionOpen.observe(labelWrap, {attributes: true, attributeFilter: ['class']}); + + if (extra) { + labelWrap.insertBefore(extra, labelWrap.lastElementChild); } + + accordion.onChecked = function(checked) { + if (isOpen() != checked) { + labelWrap.click(); + } + }; + + var visibleCheckbox = document.createElement('INPUT'); + visibleCheckbox.type = 'checkbox'; + visibleCheckbox.checked = isOpen(); + visibleCheckbox.id = accordion.id + "-visible-checkbox"; + visibleCheckbox.className = gradioCheckbox.className + " input-accordion-checkbox"; + span.insertBefore(visibleCheckbox, span.firstChild); + + accordion.visibleCheckbox = visibleCheckbox; + accordion.onVisibleCheckboxChange = function() { + if (linked && isOpen() != visibleCheckbox.checked) { + labelWrap.click(); + } + + gradioCheckbox.checked = visibleCheckbox.checked; + updateInput(gradioCheckbox); + }; + + visibleCheckbox.addEventListener('click', function(event) { + linked = false; + event.stopPropagation(); + }); + visibleCheckbox.addEventListener('input', accordion.onVisibleCheckboxChange); } onUiLoaded(function() { for (var accordion of gradioApp().querySelectorAll('.input-accordion')) { - var labelWrap = accordion.querySelector('.label-wrap'); - observerAccordionOpen.observe(labelWrap, {attributes: true, attributeFilter: ['class']}); - - var extra = gradioApp().querySelector('#' + accordion.id + "-extra"); - if (extra) { - labelWrap.insertBefore(extra, labelWrap.lastElementChild); - } + setupAccordion(accordion); } }); diff --git a/javascript/notification.js b/javascript/notification.js index 6d7995612..3ee972ae1 100644 --- a/javascript/notification.js +++ b/javascript/notification.js @@ -26,7 +26,11 @@ onAfterUiUpdate(function() { lastHeadImg = headImg; // play notification sound if available - gradioApp().querySelector('#audio_notification audio')?.play(); + const notificationAudio = gradioApp().querySelector('#audio_notification audio'); + if (notificationAudio) { + notificationAudio.volume = opts.notification_volume / 100.0 || 1.0; + notificationAudio.play(); + } if (document.hasFocus()) return; diff --git a/javascript/settings.js b/javascript/settings.js index 4e79ec003..e6009290a 100644 --- a/javascript/settings.js +++ b/javascript/settings.js @@ -44,3 +44,28 @@ onUiLoaded(function() { buttonShowAllPages.addEventListener("click", settingsShowAllTabs); }); + + +onOptionsChanged(function() { + if (gradioApp().querySelector('#settings .settings-category')) return; + + var sectionMap = {}; + gradioApp().querySelectorAll('#settings > div > button').forEach(function(x) { + sectionMap[x.textContent.trim()] = x; + }); + + opts._categories.forEach(function(x) { + var section = x[0]; + var category = x[1]; + + var span = document.createElement('SPAN'); + span.textContent = category; + span.className = 'settings-category'; + + var sectionElem = sectionMap[section]; + if (!sectionElem) return; + + sectionElem.parentElement.insertBefore(span, sectionElem); + }); +}); + diff --git a/javascript/ui.js b/javascript/ui.js index 2e2626020..18c9f891a 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -170,6 +170,23 @@ function submit_img2img() { return res; } +function submit_extras() { + showSubmitButtons('extras', false); + + var id = randomId(); + + requestProgress(id, gradioApp().getElementById('extras_gallery_container'), gradioApp().getElementById('extras_gallery'), function() { + showSubmitButtons('extras', true); + }); + + var res = create_submit_args(arguments); + + res[0] = id; + + console.log(res); + return res; +} + function restoreProgressTxt2img() { showRestoreProgressButton("txt2img", false); var id = localGet("txt2img_task_id"); @@ -198,9 +215,33 @@ function restoreProgressImg2img() { } +/** + * Configure the width and height elements on `tabname` to accept + * pasting of resolutions in the form of "width x height". + */ +function setupResolutionPasting(tabname) { + var width = gradioApp().querySelector(`#${tabname}_width input[type=number]`); + var height = gradioApp().querySelector(`#${tabname}_height input[type=number]`); + for (const el of [width, height]) { + el.addEventListener('paste', function(event) { + var pasteData = event.clipboardData.getData('text/plain'); + var parsed = pasteData.match(/^\s*(\d+)\D+(\d+)\s*$/); + if (parsed) { + width.value = parsed[1]; + height.value = parsed[2]; + updateInput(width); + updateInput(height); + event.preventDefault(); + } + }); + } +} + onUiLoaded(function() { showRestoreProgressButton('txt2img', localGet("txt2img_task_id")); showRestoreProgressButton('img2img', localGet("img2img_task_id")); + setupResolutionPasting('txt2img'); + setupResolutionPasting('img2img'); }); diff --git a/modules/api/api.py b/modules/api/api.py index 090838747..0e2807de2 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -17,12 +17,11 @@ from fastapi.encoders import jsonable_encoder from secrets import compare_digest import modules.shared as shared -from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, generation_parameters_copypaste, sd_models +from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, infotext, sd_models from modules.api import models from modules.shared import opts from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images from modules.textual_inversion.textual_inversion import create_embedding, train_embedding -from modules.textual_inversion.preprocess import preprocess from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork from PIL import PngImagePlugin, Image from modules.sd_models_config import find_checkpoint_config_near_filename @@ -32,7 +31,7 @@ from typing import Any import piexif import piexif.helper from contextlib import closing - +from modules.progress import create_task_id, add_task_to_queue, start_task, finish_task, current_task def script_name_to_index(name, scripts): try: @@ -235,7 +234,6 @@ class Api: self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"]) self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse) self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=models.CreateResponse) - self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=models.PreprocessResponse) self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=models.TrainResponse) self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=models.TrainResponse) self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=models.MemoryResponse) @@ -253,6 +251,24 @@ class Api: self.default_script_arg_txt2img = [] self.default_script_arg_img2img = [] + txt2img_script_runner = scripts.scripts_txt2img + img2img_script_runner = scripts.scripts_img2img + + if not txt2img_script_runner.scripts or not img2img_script_runner.scripts: + ui.create_ui() + + if not txt2img_script_runner.scripts: + txt2img_script_runner.initialize_scripts(False) + if not self.default_script_arg_txt2img: + self.default_script_arg_txt2img = self.init_default_script_args(txt2img_script_runner) + + if not img2img_script_runner.scripts: + img2img_script_runner.initialize_scripts(True) + if not self.default_script_arg_img2img: + self.default_script_arg_img2img = self.init_default_script_args(img2img_script_runner) + + + def add_api_route(self, path: str, endpoint, **kwargs): if shared.cmd_opts.api_auth: return self.app.add_api_route(path, endpoint, dependencies=[Depends(self.auth)], **kwargs) @@ -314,8 +330,13 @@ class Api: script_args[script.args_from:script.args_to] = ui_default_values return script_args - def init_script_args(self, request, default_script_args, selectable_scripts, selectable_idx, script_runner): + def init_script_args(self, request, default_script_args, selectable_scripts, selectable_idx, script_runner, *, input_script_args=None): script_args = default_script_args.copy() + + if input_script_args is not None: + for index, value in input_script_args.items(): + script_args[index] = value + # position 0 in script_arg is the idx+1 of the selectable script that is going to be run when using scripts.scripts_*2img.run() if selectable_scripts: script_args[selectable_scripts.args_from:selectable_scripts.args_to] = request.script_args @@ -337,13 +358,83 @@ class Api: script_args[alwayson_script.args_from + idx] = request.alwayson_scripts[alwayson_script_name]["args"][idx] return script_args + def apply_infotext(self, request, tabname, *, script_runner=None, mentioned_script_args=None): + """Processes `infotext` field from the `request`, and sets other fields of the `request` accoring to what's in infotext. + + If request already has a field set, and that field is encountered in infotext too, the value from infotext is ignored. + + Additionally, fills `mentioned_script_args` dict with index: value pairs for script arguments read from infotext. + """ + + if not request.infotext: + return {} + + possible_fields = infotext.paste_fields[tabname]["fields"] + set_fields = request.model_dump(exclude_unset=True) if hasattr(request, "request") else request.dict(exclude_unset=True) # pydantic v1/v2 have differenrt names for this + params = infotext.parse_generation_parameters(request.infotext) + + def get_field_value(field, params): + value = field.function(params) if field.function else params.get(field.label) + if value is None: + return None + + if field.api in request.__fields__: + target_type = request.__fields__[field.api].type_ + else: + target_type = type(field.component.value) + + if target_type == type(None): + return None + + if isinstance(value, dict) and value.get('__type__') == 'generic_update': # this is a gradio.update rather than a value + value = value.get('value') + + if value is not None and not isinstance(value, target_type): + value = target_type(value) + + return value + + for field in possible_fields: + if not field.api: + continue + + if field.api in set_fields: + continue + + value = get_field_value(field, params) + if value is not None: + setattr(request, field.api, value) + + if request.override_settings is None: + request.override_settings = {} + + overriden_settings = infotext.get_override_settings(params) + for _, setting_name, value in overriden_settings: + if setting_name not in request.override_settings: + request.override_settings[setting_name] = value + + if script_runner is not None and mentioned_script_args is not None: + indexes = {v: i for i, v in enumerate(script_runner.inputs)} + script_fields = ((field, indexes[field.component]) for field in possible_fields if field.component in indexes) + + for field, index in script_fields: + value = get_field_value(field, params) + + if value is None: + continue + + mentioned_script_args[index] = value + + return params + def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI): + task_id = txt2imgreq.force_task_id or create_task_id("txt2img") + script_runner = scripts.scripts_txt2img - if not script_runner.scripts: - script_runner.initialize_scripts(False) - ui.create_ui() - if not self.default_script_arg_txt2img: - self.default_script_arg_txt2img = self.init_default_script_args(script_runner) + + infotext_script_args = {} + self.apply_infotext(txt2imgreq, "txt2img", script_runner=script_runner, mentioned_script_args=infotext_script_args) + selectable_scripts, selectable_script_idx = self.get_selectable_script(txt2imgreq.script_name, script_runner) populate = txt2imgreq.copy(update={ # Override __init__ params @@ -358,12 +449,15 @@ class Api: args.pop('script_name', None) args.pop('script_args', None) # will refeed them to the pipeline directly after initializing them args.pop('alwayson_scripts', None) + args.pop('infotext', None) - script_args = self.init_script_args(txt2imgreq, self.default_script_arg_txt2img, selectable_scripts, selectable_script_idx, script_runner) + script_args = self.init_script_args(txt2imgreq, self.default_script_arg_txt2img, selectable_scripts, selectable_script_idx, script_runner, input_script_args=infotext_script_args) send_images = args.pop('send_images', True) args.pop('save_images', None) + add_task_to_queue(task_id) + with self.queue_lock: with closing(StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)) as p: p.is_api = True @@ -373,12 +467,14 @@ class Api: try: shared.state.begin(job="scripts_txt2img") + start_task(task_id) if selectable_scripts is not None: p.script_args = script_args processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here else: p.script_args = tuple(script_args) # Need to pass args as tuple here processed = process_images(p) + finish_task(task_id) finally: shared.state.end() shared.total_tqdm.clear() @@ -388,6 +484,8 @@ class Api: return models.TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js()) def img2imgapi(self, img2imgreq: models.StableDiffusionImg2ImgProcessingAPI): + task_id = img2imgreq.force_task_id or create_task_id("img2img") + init_images = img2imgreq.init_images if init_images is None: raise HTTPException(status_code=404, detail="Init image not found") @@ -397,11 +495,10 @@ class Api: mask = decode_base64_to_image(mask) script_runner = scripts.scripts_img2img - if not script_runner.scripts: - script_runner.initialize_scripts(True) - ui.create_ui() - if not self.default_script_arg_img2img: - self.default_script_arg_img2img = self.init_default_script_args(script_runner) + + infotext_script_args = {} + self.apply_infotext(img2imgreq, "img2img", script_runner=script_runner, mentioned_script_args=infotext_script_args) + selectable_scripts, selectable_script_idx = self.get_selectable_script(img2imgreq.script_name, script_runner) populate = img2imgreq.copy(update={ # Override __init__ params @@ -418,12 +515,15 @@ class Api: args.pop('script_name', None) args.pop('script_args', None) # will refeed them to the pipeline directly after initializing them args.pop('alwayson_scripts', None) + args.pop('infotext', None) - script_args = self.init_script_args(img2imgreq, self.default_script_arg_img2img, selectable_scripts, selectable_script_idx, script_runner) + script_args = self.init_script_args(img2imgreq, self.default_script_arg_img2img, selectable_scripts, selectable_script_idx, script_runner, input_script_args=infotext_script_args) send_images = args.pop('send_images', True) args.pop('save_images', None) + add_task_to_queue(task_id) + with self.queue_lock: with closing(StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)) as p: p.init_images = [decode_base64_to_image(x) for x in init_images] @@ -434,12 +534,14 @@ class Api: try: shared.state.begin(job="scripts_img2img") + start_task(task_id) if selectable_scripts is not None: p.script_args = script_args processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here else: p.script_args = tuple(script_args) # Need to pass args as tuple here processed = process_images(p) + finish_task(task_id) finally: shared.state.end() shared.total_tqdm.clear() @@ -482,7 +584,7 @@ class Api: if geninfo is None: geninfo = "" - params = generation_parameters_copypaste.parse_generation_parameters(geninfo) + params = infotext.parse_generation_parameters(geninfo) script_callbacks.infotext_pasted_callback(geninfo, params) return models.PNGInfoResponse(info=geninfo, items=items, parameters=params) @@ -513,7 +615,7 @@ class Api: if shared.state.current_image and not req.skip_current_image: current_image = encode_pil_to_base64(shared.state.current_image) - return models.ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo) + return models.ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo, current_task=current_task) def interrogateapi(self, interrogatereq: models.InterrogateRequest): image_b64 = interrogatereq.image @@ -675,19 +777,6 @@ class Api: finally: shared.state.end() - def preprocess(self, args: dict): - try: - shared.state.begin(job="preprocess") - preprocess(**args) # quick operation unless blip/booru interrogation is enabled - shared.state.end() - return models.PreprocessResponse(info='preprocess complete') - except KeyError as e: - return models.PreprocessResponse(info=f"preprocess error: invalid token: {e}") - except Exception as e: - return models.PreprocessResponse(info=f"preprocess error: {e}") - finally: - shared.state.end() - def train_embedding(self, args: dict): try: shared.state.begin(job="train_embedding") diff --git a/modules/api/models.py b/modules/api/models.py index a0d80af8c..16edf11cf 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -107,6 +107,8 @@ StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator( {"key": "send_images", "type": bool, "default": True}, {"key": "save_images", "type": bool, "default": False}, {"key": "alwayson_scripts", "type": dict, "default": {}}, + {"key": "force_task_id", "type": str, "default": None}, + {"key": "infotext", "type": str, "default": None}, ] ).generate_model() @@ -124,6 +126,8 @@ StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator( {"key": "send_images", "type": bool, "default": True}, {"key": "save_images", "type": bool, "default": False}, {"key": "alwayson_scripts", "type": dict, "default": {}}, + {"key": "force_task_id", "type": str, "default": None}, + {"key": "infotext", "type": str, "default": None}, ] ).generate_model() @@ -202,9 +206,6 @@ class TrainResponse(BaseModel): class CreateResponse(BaseModel): info: str = Field(title="Create info", description="Response string from create embedding or hypernetwork task.") -class PreprocessResponse(BaseModel): - info: str = Field(title="Preprocess info", description="Response string from preprocessing task.") - fields = {} for key, metadata in opts.data_labels.items(): value = opts.data.get(key) diff --git a/modules/cache.py b/modules/cache.py index ff26a2132..2d37e7b99 100644 --- a/modules/cache.py +++ b/modules/cache.py @@ -32,7 +32,7 @@ def dump_cache(): with cache_lock: cache_filename_tmp = cache_filename + "-" with open(cache_filename_tmp, "w", encoding="utf8") as file: - json.dump(cache_data, file, indent=4) + json.dump(cache_data, file, indent=4, ensure_ascii=False) os.replace(cache_filename_tmp, cache_filename) diff --git a/modules/call_queue.py b/modules/call_queue.py index ddf0d5738..bcd7c5462 100644 --- a/modules/call_queue.py +++ b/modules/call_queue.py @@ -78,6 +78,7 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False): shared.state.skipped = False shared.state.interrupted = False + shared.state.stopping_generation = False shared.state.job_count = 0 if not add_stats: diff --git a/modules/cmd_args.py b/modules/cmd_args.py index c6d4b6123..3775e6702 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -70,6 +70,7 @@ parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="pre parser.add_argument("--disable-opt-split-attention", action='store_true', help="prefer no cross-attention layer optimization for automatic choice of optimization") parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI") parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower) +parser.add_argument("--use-ipex", action="store_true", help="use Intel XPU as torch device") parser.add_argument("--disable-model-loading-ram-optimization", action='store_true', help="disable an optimization that reduces RAM use when loading a model") parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests") parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None) @@ -109,7 +110,7 @@ parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, req parser.add_argument("--disable-tls-verify", action="store_false", help="When passed, enables the use of self-signed certificates.", default=None) parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None) parser.add_argument("--gradio-queue", action='store_true', help="does not do anything", default=True) -parser.add_argument("--no-gradio-queue", action='store_true', help="Disables gradio queue; causes the webpage to use http requests instead of websockets; was the defaul in earlier versions") +parser.add_argument("--no-gradio-queue", action='store_true', help="Disables gradio queue; causes the webpage to use http requests instead of websockets; was the default in earlier versions") parser.add_argument("--skip-version-check", action='store_true', help="Do not check versions of torch and xformers") parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False) parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False) diff --git a/modules/codeformer/codeformer_arch.py b/modules/codeformer/codeformer_arch.py deleted file mode 100644 index 12db68142..000000000 --- a/modules/codeformer/codeformer_arch.py +++ /dev/null @@ -1,276 +0,0 @@ -# this file is copied from CodeFormer repository. Please see comment in modules/codeformer_model.py - -import math -import torch -from torch import nn, Tensor -import torch.nn.functional as F -from typing import Optional - -from modules.codeformer.vqgan_arch import VQAutoEncoder, ResBlock -from basicsr.utils.registry import ARCH_REGISTRY - -def calc_mean_std(feat, eps=1e-5): - """Calculate mean and std for adaptive_instance_normalization. - - Args: - feat (Tensor): 4D tensor. - eps (float): A small value added to the variance to avoid - divide-by-zero. Default: 1e-5. - """ - size = feat.size() - assert len(size) == 4, 'The input feature should be 4D tensor.' - b, c = size[:2] - feat_var = feat.view(b, c, -1).var(dim=2) + eps - feat_std = feat_var.sqrt().view(b, c, 1, 1) - feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1) - return feat_mean, feat_std - - -def adaptive_instance_normalization(content_feat, style_feat): - """Adaptive instance normalization. - - Adjust the reference features to have the similar color and illuminations - as those in the degradate features. - - Args: - content_feat (Tensor): The reference feature. - style_feat (Tensor): The degradate features. - """ - size = content_feat.size() - style_mean, style_std = calc_mean_std(style_feat) - content_mean, content_std = calc_mean_std(content_feat) - normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size) - return normalized_feat * style_std.expand(size) + style_mean.expand(size) - - -class PositionEmbeddingSine(nn.Module): - """ - This is a more standard version of the position embedding, very similar to the one - used by the Attention is all you need paper, generalized to work on images. - """ - - def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): - super().__init__() - self.num_pos_feats = num_pos_feats - self.temperature = temperature - self.normalize = normalize - if scale is not None and normalize is False: - raise ValueError("normalize should be True if scale is passed") - if scale is None: - scale = 2 * math.pi - self.scale = scale - - def forward(self, x, mask=None): - if mask is None: - mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool) - not_mask = ~mask - y_embed = not_mask.cumsum(1, dtype=torch.float32) - x_embed = not_mask.cumsum(2, dtype=torch.float32) - if self.normalize: - eps = 1e-6 - y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale - x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale - - dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) - dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) - - pos_x = x_embed[:, :, :, None] / dim_t - pos_y = y_embed[:, :, :, None] / dim_t - pos_x = torch.stack( - (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 - ).flatten(3) - pos_y = torch.stack( - (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 - ).flatten(3) - pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) - return pos - -def _get_activation_fn(activation): - """Return an activation function given a string""" - if activation == "relu": - return F.relu - if activation == "gelu": - return F.gelu - if activation == "glu": - return F.glu - raise RuntimeError(F"activation should be relu/gelu, not {activation}.") - - -class TransformerSALayer(nn.Module): - def __init__(self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"): - super().__init__() - self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout) - # Implementation of Feedforward model - MLP - self.linear1 = nn.Linear(embed_dim, dim_mlp) - self.dropout = nn.Dropout(dropout) - self.linear2 = nn.Linear(dim_mlp, embed_dim) - - self.norm1 = nn.LayerNorm(embed_dim) - self.norm2 = nn.LayerNorm(embed_dim) - self.dropout1 = nn.Dropout(dropout) - self.dropout2 = nn.Dropout(dropout) - - self.activation = _get_activation_fn(activation) - - def with_pos_embed(self, tensor, pos: Optional[Tensor]): - return tensor if pos is None else tensor + pos - - def forward(self, tgt, - tgt_mask: Optional[Tensor] = None, - tgt_key_padding_mask: Optional[Tensor] = None, - query_pos: Optional[Tensor] = None): - - # self attention - tgt2 = self.norm1(tgt) - q = k = self.with_pos_embed(tgt2, query_pos) - tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, - key_padding_mask=tgt_key_padding_mask)[0] - tgt = tgt + self.dropout1(tgt2) - - # ffn - tgt2 = self.norm2(tgt) - tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) - tgt = tgt + self.dropout2(tgt2) - return tgt - -class Fuse_sft_block(nn.Module): - def __init__(self, in_ch, out_ch): - super().__init__() - self.encode_enc = ResBlock(2*in_ch, out_ch) - - self.scale = nn.Sequential( - nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1), - nn.LeakyReLU(0.2, True), - nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)) - - self.shift = nn.Sequential( - nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1), - nn.LeakyReLU(0.2, True), - nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)) - - def forward(self, enc_feat, dec_feat, w=1): - enc_feat = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1)) - scale = self.scale(enc_feat) - shift = self.shift(enc_feat) - residual = w * (dec_feat * scale + shift) - out = dec_feat + residual - return out - - -@ARCH_REGISTRY.register() -class CodeFormer(VQAutoEncoder): - def __init__(self, dim_embd=512, n_head=8, n_layers=9, - codebook_size=1024, latent_size=256, - connect_list=('32', '64', '128', '256'), - fix_modules=('quantize', 'generator')): - super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size) - - if fix_modules is not None: - for module in fix_modules: - for param in getattr(self, module).parameters(): - param.requires_grad = False - - self.connect_list = connect_list - self.n_layers = n_layers - self.dim_embd = dim_embd - self.dim_mlp = dim_embd*2 - - self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd)) - self.feat_emb = nn.Linear(256, self.dim_embd) - - # transformer - self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0) - for _ in range(self.n_layers)]) - - # logits_predict head - self.idx_pred_layer = nn.Sequential( - nn.LayerNorm(dim_embd), - nn.Linear(dim_embd, codebook_size, bias=False)) - - self.channels = { - '16': 512, - '32': 256, - '64': 256, - '128': 128, - '256': 128, - '512': 64, - } - - # after second residual block for > 16, before attn layer for ==16 - self.fuse_encoder_block = {'512':2, '256':5, '128':8, '64':11, '32':14, '16':18} - # after first residual block for > 16, before attn layer for ==16 - self.fuse_generator_block = {'16':6, '32': 9, '64':12, '128':15, '256':18, '512':21} - - # fuse_convs_dict - self.fuse_convs_dict = nn.ModuleDict() - for f_size in self.connect_list: - in_ch = self.channels[f_size] - self.fuse_convs_dict[f_size] = Fuse_sft_block(in_ch, in_ch) - - def _init_weights(self, module): - if isinstance(module, (nn.Linear, nn.Embedding)): - module.weight.data.normal_(mean=0.0, std=0.02) - if isinstance(module, nn.Linear) and module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - - def forward(self, x, w=0, detach_16=True, code_only=False, adain=False): - # ################### Encoder ##################### - enc_feat_dict = {} - out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list] - for i, block in enumerate(self.encoder.blocks): - x = block(x) - if i in out_list: - enc_feat_dict[str(x.shape[-1])] = x.clone() - - lq_feat = x - # ################# Transformer ################### - # quant_feat, codebook_loss, quant_stats = self.quantize(lq_feat) - pos_emb = self.position_emb.unsqueeze(1).repeat(1,x.shape[0],1) - # BCHW -> BC(HW) -> (HW)BC - feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2,0,1)) - query_emb = feat_emb - # Transformer encoder - for layer in self.ft_layers: - query_emb = layer(query_emb, query_pos=pos_emb) - - # output logits - logits = self.idx_pred_layer(query_emb) # (hw)bn - logits = logits.permute(1,0,2) # (hw)bn -> b(hw)n - - if code_only: # for training stage II - # logits doesn't need softmax before cross_entropy loss - return logits, lq_feat - - # ################# Quantization ################### - # if self.training: - # quant_feat = torch.einsum('btn,nc->btc', [soft_one_hot, self.quantize.embedding.weight]) - # # b(hw)c -> bc(hw) -> bchw - # quant_feat = quant_feat.permute(0,2,1).view(lq_feat.shape) - # ------------ - soft_one_hot = F.softmax(logits, dim=2) - _, top_idx = torch.topk(soft_one_hot, 1, dim=2) - quant_feat = self.quantize.get_codebook_feat(top_idx, shape=[x.shape[0],16,16,256]) - # preserve gradients - # quant_feat = lq_feat + (quant_feat - lq_feat).detach() - - if detach_16: - quant_feat = quant_feat.detach() # for training stage III - if adain: - quant_feat = adaptive_instance_normalization(quant_feat, lq_feat) - - # ################## Generator #################### - x = quant_feat - fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list] - - for i, block in enumerate(self.generator.blocks): - x = block(x) - if i in fuse_list: # fuse after i-th block - f_size = str(x.shape[-1]) - if w>0: - x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w) - out = x - # logits doesn't need softmax before cross_entropy loss - return out, logits, lq_feat diff --git a/modules/codeformer/vqgan_arch.py b/modules/codeformer/vqgan_arch.py deleted file mode 100644 index 09ee6660d..000000000 --- a/modules/codeformer/vqgan_arch.py +++ /dev/null @@ -1,435 +0,0 @@ -# this file is copied from CodeFormer repository. Please see comment in modules/codeformer_model.py - -''' -VQGAN code, adapted from the original created by the Unleashing Transformers authors: -https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py - -''' -import torch -import torch.nn as nn -import torch.nn.functional as F -from basicsr.utils import get_root_logger -from basicsr.utils.registry import ARCH_REGISTRY - -def normalize(in_channels): - return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) - - -@torch.jit.script -def swish(x): - return x*torch.sigmoid(x) - - -# Define VQVAE classes -class VectorQuantizer(nn.Module): - def __init__(self, codebook_size, emb_dim, beta): - super(VectorQuantizer, self).__init__() - self.codebook_size = codebook_size # number of embeddings - self.emb_dim = emb_dim # dimension of embedding - self.beta = beta # commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2 - self.embedding = nn.Embedding(self.codebook_size, self.emb_dim) - self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size) - - def forward(self, z): - # reshape z -> (batch, height, width, channel) and flatten - z = z.permute(0, 2, 3, 1).contiguous() - z_flattened = z.view(-1, self.emb_dim) - - # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z - d = (z_flattened ** 2).sum(dim=1, keepdim=True) + (self.embedding.weight**2).sum(1) - \ - 2 * torch.matmul(z_flattened, self.embedding.weight.t()) - - mean_distance = torch.mean(d) - # find closest encodings - # min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1) - min_encoding_scores, min_encoding_indices = torch.topk(d, 1, dim=1, largest=False) - # [0-1], higher score, higher confidence - min_encoding_scores = torch.exp(-min_encoding_scores/10) - - min_encodings = torch.zeros(min_encoding_indices.shape[0], self.codebook_size).to(z) - min_encodings.scatter_(1, min_encoding_indices, 1) - - # get quantized latent vectors - z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape) - # compute loss for embedding - loss = torch.mean((z_q.detach()-z)**2) + self.beta * torch.mean((z_q - z.detach()) ** 2) - # preserve gradients - z_q = z + (z_q - z).detach() - - # perplexity - e_mean = torch.mean(min_encodings, dim=0) - perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10))) - # reshape back to match original input shape - z_q = z_q.permute(0, 3, 1, 2).contiguous() - - return z_q, loss, { - "perplexity": perplexity, - "min_encodings": min_encodings, - "min_encoding_indices": min_encoding_indices, - "min_encoding_scores": min_encoding_scores, - "mean_distance": mean_distance - } - - def get_codebook_feat(self, indices, shape): - # input indices: batch*token_num -> (batch*token_num)*1 - # shape: batch, height, width, channel - indices = indices.view(-1,1) - min_encodings = torch.zeros(indices.shape[0], self.codebook_size).to(indices) - min_encodings.scatter_(1, indices, 1) - # get quantized latent vectors - z_q = torch.matmul(min_encodings.float(), self.embedding.weight) - - if shape is not None: # reshape back to match original input shape - z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous() - - return z_q - - -class GumbelQuantizer(nn.Module): - def __init__(self, codebook_size, emb_dim, num_hiddens, straight_through=False, kl_weight=5e-4, temp_init=1.0): - super().__init__() - self.codebook_size = codebook_size # number of embeddings - self.emb_dim = emb_dim # dimension of embedding - self.straight_through = straight_through - self.temperature = temp_init - self.kl_weight = kl_weight - self.proj = nn.Conv2d(num_hiddens, codebook_size, 1) # projects last encoder layer to quantized logits - self.embed = nn.Embedding(codebook_size, emb_dim) - - def forward(self, z): - hard = self.straight_through if self.training else True - - logits = self.proj(z) - - soft_one_hot = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=hard) - - z_q = torch.einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight) - - # + kl divergence to the prior loss - qy = F.softmax(logits, dim=1) - diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.codebook_size + 1e-10), dim=1).mean() - min_encoding_indices = soft_one_hot.argmax(dim=1) - - return z_q, diff, { - "min_encoding_indices": min_encoding_indices - } - - -class Downsample(nn.Module): - def __init__(self, in_channels): - super().__init__() - self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) - - def forward(self, x): - pad = (0, 1, 0, 1) - x = torch.nn.functional.pad(x, pad, mode="constant", value=0) - x = self.conv(x) - return x - - -class Upsample(nn.Module): - def __init__(self, in_channels): - super().__init__() - self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) - - def forward(self, x): - x = F.interpolate(x, scale_factor=2.0, mode="nearest") - x = self.conv(x) - - return x - - -class ResBlock(nn.Module): - def __init__(self, in_channels, out_channels=None): - super(ResBlock, self).__init__() - self.in_channels = in_channels - self.out_channels = in_channels if out_channels is None else out_channels - self.norm1 = normalize(in_channels) - self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) - self.norm2 = normalize(out_channels) - self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) - if self.in_channels != self.out_channels: - self.conv_out = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) - - def forward(self, x_in): - x = x_in - x = self.norm1(x) - x = swish(x) - x = self.conv1(x) - x = self.norm2(x) - x = swish(x) - x = self.conv2(x) - if self.in_channels != self.out_channels: - x_in = self.conv_out(x_in) - - return x + x_in - - -class AttnBlock(nn.Module): - def __init__(self, in_channels): - super().__init__() - self.in_channels = in_channels - - self.norm = normalize(in_channels) - self.q = torch.nn.Conv2d( - in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0 - ) - self.k = torch.nn.Conv2d( - in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0 - ) - self.v = torch.nn.Conv2d( - in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0 - ) - self.proj_out = torch.nn.Conv2d( - in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0 - ) - - def forward(self, x): - h_ = x - h_ = self.norm(h_) - q = self.q(h_) - k = self.k(h_) - v = self.v(h_) - - # compute attention - b, c, h, w = q.shape - q = q.reshape(b, c, h*w) - q = q.permute(0, 2, 1) - k = k.reshape(b, c, h*w) - w_ = torch.bmm(q, k) - w_ = w_ * (int(c)**(-0.5)) - w_ = F.softmax(w_, dim=2) - - # attend to values - v = v.reshape(b, c, h*w) - w_ = w_.permute(0, 2, 1) - h_ = torch.bmm(v, w_) - h_ = h_.reshape(b, c, h, w) - - h_ = self.proj_out(h_) - - return x+h_ - - -class Encoder(nn.Module): - def __init__(self, in_channels, nf, emb_dim, ch_mult, num_res_blocks, resolution, attn_resolutions): - super().__init__() - self.nf = nf - self.num_resolutions = len(ch_mult) - self.num_res_blocks = num_res_blocks - self.resolution = resolution - self.attn_resolutions = attn_resolutions - - curr_res = self.resolution - in_ch_mult = (1,)+tuple(ch_mult) - - blocks = [] - # initial convultion - blocks.append(nn.Conv2d(in_channels, nf, kernel_size=3, stride=1, padding=1)) - - # residual and downsampling blocks, with attention on smaller res (16x16) - for i in range(self.num_resolutions): - block_in_ch = nf * in_ch_mult[i] - block_out_ch = nf * ch_mult[i] - for _ in range(self.num_res_blocks): - blocks.append(ResBlock(block_in_ch, block_out_ch)) - block_in_ch = block_out_ch - if curr_res in attn_resolutions: - blocks.append(AttnBlock(block_in_ch)) - - if i != self.num_resolutions - 1: - blocks.append(Downsample(block_in_ch)) - curr_res = curr_res // 2 - - # non-local attention block - blocks.append(ResBlock(block_in_ch, block_in_ch)) - blocks.append(AttnBlock(block_in_ch)) - blocks.append(ResBlock(block_in_ch, block_in_ch)) - - # normalise and convert to latent size - blocks.append(normalize(block_in_ch)) - blocks.append(nn.Conv2d(block_in_ch, emb_dim, kernel_size=3, stride=1, padding=1)) - self.blocks = nn.ModuleList(blocks) - - def forward(self, x): - for block in self.blocks: - x = block(x) - - return x - - -class Generator(nn.Module): - def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions): - super().__init__() - self.nf = nf - self.ch_mult = ch_mult - self.num_resolutions = len(self.ch_mult) - self.num_res_blocks = res_blocks - self.resolution = img_size - self.attn_resolutions = attn_resolutions - self.in_channels = emb_dim - self.out_channels = 3 - block_in_ch = self.nf * self.ch_mult[-1] - curr_res = self.resolution // 2 ** (self.num_resolutions-1) - - blocks = [] - # initial conv - blocks.append(nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1)) - - # non-local attention block - blocks.append(ResBlock(block_in_ch, block_in_ch)) - blocks.append(AttnBlock(block_in_ch)) - blocks.append(ResBlock(block_in_ch, block_in_ch)) - - for i in reversed(range(self.num_resolutions)): - block_out_ch = self.nf * self.ch_mult[i] - - for _ in range(self.num_res_blocks): - blocks.append(ResBlock(block_in_ch, block_out_ch)) - block_in_ch = block_out_ch - - if curr_res in self.attn_resolutions: - blocks.append(AttnBlock(block_in_ch)) - - if i != 0: - blocks.append(Upsample(block_in_ch)) - curr_res = curr_res * 2 - - blocks.append(normalize(block_in_ch)) - blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1)) - - self.blocks = nn.ModuleList(blocks) - - - def forward(self, x): - for block in self.blocks: - x = block(x) - - return x - - -@ARCH_REGISTRY.register() -class VQAutoEncoder(nn.Module): - def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=None, codebook_size=1024, emb_dim=256, - beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None): - super().__init__() - logger = get_root_logger() - self.in_channels = 3 - self.nf = nf - self.n_blocks = res_blocks - self.codebook_size = codebook_size - self.embed_dim = emb_dim - self.ch_mult = ch_mult - self.resolution = img_size - self.attn_resolutions = attn_resolutions or [16] - self.quantizer_type = quantizer - self.encoder = Encoder( - self.in_channels, - self.nf, - self.embed_dim, - self.ch_mult, - self.n_blocks, - self.resolution, - self.attn_resolutions - ) - if self.quantizer_type == "nearest": - self.beta = beta #0.25 - self.quantize = VectorQuantizer(self.codebook_size, self.embed_dim, self.beta) - elif self.quantizer_type == "gumbel": - self.gumbel_num_hiddens = emb_dim - self.straight_through = gumbel_straight_through - self.kl_weight = gumbel_kl_weight - self.quantize = GumbelQuantizer( - self.codebook_size, - self.embed_dim, - self.gumbel_num_hiddens, - self.straight_through, - self.kl_weight - ) - self.generator = Generator( - self.nf, - self.embed_dim, - self.ch_mult, - self.n_blocks, - self.resolution, - self.attn_resolutions - ) - - if model_path is not None: - chkpt = torch.load(model_path, map_location='cpu') - if 'params_ema' in chkpt: - self.load_state_dict(torch.load(model_path, map_location='cpu')['params_ema']) - logger.info(f'vqgan is loaded from: {model_path} [params_ema]') - elif 'params' in chkpt: - self.load_state_dict(torch.load(model_path, map_location='cpu')['params']) - logger.info(f'vqgan is loaded from: {model_path} [params]') - else: - raise ValueError('Wrong params!') - - - def forward(self, x): - x = self.encoder(x) - quant, codebook_loss, quant_stats = self.quantize(x) - x = self.generator(quant) - return x, codebook_loss, quant_stats - - - -# patch based discriminator -@ARCH_REGISTRY.register() -class VQGANDiscriminator(nn.Module): - def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None): - super().__init__() - - layers = [nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, True)] - ndf_mult = 1 - ndf_mult_prev = 1 - for n in range(1, n_layers): # gradually increase the number of filters - ndf_mult_prev = ndf_mult - ndf_mult = min(2 ** n, 8) - layers += [ - nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=2, padding=1, bias=False), - nn.BatchNorm2d(ndf * ndf_mult), - nn.LeakyReLU(0.2, True) - ] - - ndf_mult_prev = ndf_mult - ndf_mult = min(2 ** n_layers, 8) - - layers += [ - nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=1, padding=1, bias=False), - nn.BatchNorm2d(ndf * ndf_mult), - nn.LeakyReLU(0.2, True) - ] - - layers += [ - nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)] # output 1 channel prediction map - self.main = nn.Sequential(*layers) - - if model_path is not None: - chkpt = torch.load(model_path, map_location='cpu') - if 'params_d' in chkpt: - self.load_state_dict(torch.load(model_path, map_location='cpu')['params_d']) - elif 'params' in chkpt: - self.load_state_dict(torch.load(model_path, map_location='cpu')['params']) - else: - raise ValueError('Wrong params!') - - def forward(self, x): - return self.main(x) diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py index da42b5e99..44b84618e 100644 --- a/modules/codeformer_model.py +++ b/modules/codeformer_model.py @@ -1,132 +1,64 @@ -import os +from __future__ import annotations + +import logging -import cv2 import torch -import modules.face_restoration -import modules.shared -from modules import shared, devices, modelloader, errors -from modules.paths import models_path +from modules import ( + devices, + errors, + face_restoration, + face_restoration_utils, + modelloader, + shared, +) + +logger = logging.getLogger(__name__) -# codeformer people made a choice to include modified basicsr library to their project which makes -# it utterly impossible to use it alongside with other libraries that also use basicsr, like GFPGAN. -# I am making a choice to include some files from codeformer to work around this issue. -model_dir = "Codeformer" -model_path = os.path.join(models_path, model_dir) model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth' +model_download_name = 'codeformer-v0.1.0.pth' -codeformer = None +# used by e.g. postprocessing_codeformer.py +codeformer: face_restoration.FaceRestoration | None = None -def setup_model(dirname): - os.makedirs(model_path, exist_ok=True) +class FaceRestorerCodeFormer(face_restoration_utils.CommonFaceRestoration): + def name(self): + return "CodeFormer" - path = modules.paths.paths.get("CodeFormer", None) - if path is None: - return + def load_net(self) -> torch.Module: + for model_path in modelloader.load_models( + model_path=self.model_path, + model_url=model_url, + command_path=self.model_path, + download_name=model_download_name, + ext_filter=['.pth'], + ): + return modelloader.load_spandrel_model( + model_path, + device=devices.device_codeformer, + expected_architecture='CodeFormer', + ).model + raise ValueError("No codeformer model found") + def get_device(self): + return devices.device_codeformer + + def restore(self, np_image, w: float | None = None): + if w is None: + w = getattr(shared.opts, "code_former_weight", 0.5) + + def restore_face(cropped_face_t): + assert self.net is not None + return self.net(cropped_face_t, w=w, adain=True)[0] + + return self.restore_with_helper(np_image, restore_face) + + +def setup_model(dirname: str) -> None: + global codeformer try: - from torchvision.transforms.functional import normalize - from modules.codeformer.codeformer_arch import CodeFormer - from basicsr.utils import img2tensor, tensor2img - from facelib.utils.face_restoration_helper import FaceRestoreHelper - from facelib.detection.retinaface import retinaface - - net_class = CodeFormer - - class FaceRestorerCodeFormer(modules.face_restoration.FaceRestoration): - def name(self): - return "CodeFormer" - - def __init__(self, dirname): - self.net = None - self.face_helper = None - self.cmd_dir = dirname - - def create_models(self): - - if self.net is not None and self.face_helper is not None: - self.net.to(devices.device_codeformer) - return self.net, self.face_helper - model_paths = modelloader.load_models(model_path, model_url, self.cmd_dir, download_name='codeformer-v0.1.0.pth', ext_filter=['.pth']) - if len(model_paths) != 0: - ckpt_path = model_paths[0] - else: - print("Unable to load codeformer model.") - return None, None - net = net_class(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256']).to(devices.device_codeformer) - checkpoint = torch.load(ckpt_path)['params_ema'] - net.load_state_dict(checkpoint) - net.eval() - - if hasattr(retinaface, 'device'): - retinaface.device = devices.device_codeformer - face_helper = FaceRestoreHelper(1, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png', use_parse=True, device=devices.device_codeformer) - - self.net = net - self.face_helper = face_helper - - return net, face_helper - - def send_model_to(self, device): - self.net.to(device) - self.face_helper.face_det.to(device) - self.face_helper.face_parse.to(device) - - def restore(self, np_image, w=None): - np_image = np_image[:, :, ::-1] - - original_resolution = np_image.shape[0:2] - - self.create_models() - if self.net is None or self.face_helper is None: - return np_image - - self.send_model_to(devices.device_codeformer) - - self.face_helper.clean_all() - self.face_helper.read_image(np_image) - self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5) - self.face_helper.align_warp_face() - - for cropped_face in self.face_helper.cropped_faces: - cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True) - normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) - cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer) - - try: - with torch.no_grad(): - output = self.net(cropped_face_t, w=w if w is not None else shared.opts.code_former_weight, adain=True)[0] - restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1)) - del output - devices.torch_gc() - except Exception: - errors.report('Failed inference for CodeFormer', exc_info=True) - restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1)) - - restored_face = restored_face.astype('uint8') - self.face_helper.add_restored_face(restored_face) - - self.face_helper.get_inverse_affine(None) - - restored_img = self.face_helper.paste_faces_to_input_image() - restored_img = restored_img[:, :, ::-1] - - if original_resolution != restored_img.shape[0:2]: - restored_img = cv2.resize(restored_img, (0, 0), fx=original_resolution[1]/restored_img.shape[1], fy=original_resolution[0]/restored_img.shape[0], interpolation=cv2.INTER_LINEAR) - - self.face_helper.clean_all() - - if shared.opts.face_restoration_unload: - self.send_model_to(devices.cpu) - - return restored_img - - global codeformer codeformer = FaceRestorerCodeFormer(dirname) shared.face_restorers.append(codeformer) - except Exception: errors.report("Error setting up CodeFormer", exc_info=True) - - # sys.path = stored_sys_path diff --git a/modules/devices.py b/modules/devices.py index 1d4eb5635..ff279ac50 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -4,10 +4,18 @@ from functools import lru_cache import torch from modules import errors, shared +from modules import torch_utils if sys.platform == "darwin": from modules import mac_specific +if shared.cmd_opts.use_ipex: + from modules import xpu_specific + + +def has_xpu() -> bool: + return shared.cmd_opts.use_ipex and xpu_specific.has_xpu + def has_mps() -> bool: if sys.platform != "darwin": @@ -16,6 +24,23 @@ def has_mps() -> bool: return mac_specific.has_mps +def cuda_no_autocast(device_id=None) -> bool: + if device_id is None: + device_id = get_cuda_device_id() + return ( + torch.cuda.get_device_capability(device_id) == (7, 5) + and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16") + ) + + +def get_cuda_device_id(): + return ( + int(shared.cmd_opts.device_id) + if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit() + else 0 + ) or torch.cuda.current_device() + + def get_cuda_device_string(): if shared.cmd_opts.device_id is not None: return f"cuda:{shared.cmd_opts.device_id}" @@ -30,6 +55,9 @@ def get_optimal_device_name(): if has_mps(): return "mps" + if has_xpu(): + return xpu_specific.get_xpu_device_string() + return "cpu" @@ -38,7 +66,7 @@ def get_optimal_device(): def get_device_for(task): - if task in shared.cmd_opts.use_cpu: + if task in shared.cmd_opts.use_cpu or "all" in shared.cmd_opts.use_cpu: return cpu return get_optimal_device() @@ -54,14 +82,16 @@ def torch_gc(): if has_mps(): mac_specific.torch_mps_gc() + if has_xpu(): + xpu_specific.torch_xpu_gc() + def enable_tf32(): if torch.cuda.is_available(): # enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't # see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407 - device_id = (int(shared.cmd_opts.device_id) if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit() else 0) or torch.cuda.current_device() - if torch.cuda.get_device_capability(device_id) == (7, 5) and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16"): + if cuda_no_autocast(): torch.backends.cudnn.benchmark = True torch.backends.cuda.matmul.allow_tf32 = True @@ -71,6 +101,7 @@ def enable_tf32(): errors.run(enable_tf32, "Enabling TF32") cpu: torch.device = torch.device("cpu") +fp8: bool = False device: torch.device = None device_interrogate: torch.device = None device_gfpgan: torch.device = None @@ -91,12 +122,51 @@ def cond_cast_float(input): nv_rng = None +patch_module_list = [ + torch.nn.Linear, + torch.nn.Conv2d, + torch.nn.MultiheadAttention, + torch.nn.GroupNorm, + torch.nn.LayerNorm, +] + + +def manual_cast_forward(self, *args, **kwargs): + org_dtype = torch_utils.get_param(self).dtype + self.to(dtype) + args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args] + kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()} + result = self.org_forward(*args, **kwargs) + self.to(org_dtype) + return result + + +@contextlib.contextmanager +def manual_cast(): + for module_type in patch_module_list: + org_forward = module_type.forward + module_type.forward = manual_cast_forward + module_type.org_forward = org_forward + try: + yield None + finally: + for module_type in patch_module_list: + module_type.forward = module_type.org_forward def autocast(disable=False): if disable: return contextlib.nullcontext() + if fp8 and device==cpu: + return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True) + + if fp8 and (dtype == torch.float32 or shared.cmd_opts.precision == "full" or cuda_no_autocast()): + return manual_cast() + + if has_mps() and shared.cmd_opts.precision != "full": + return manual_cast() + if dtype == torch.float32 or shared.cmd_opts.precision == "full": return contextlib.nullcontext() diff --git a/modules/errors.py b/modules/errors.py index 8c339464d..48aa13a17 100644 --- a/modules/errors.py +++ b/modules/errors.py @@ -6,6 +6,21 @@ import traceback exception_records = [] +def format_traceback(tb): + return [[f"{x.filename}, line {x.lineno}, {x.name}", x.line] for x in traceback.extract_tb(tb)] + + +def format_exception(e, tb): + return {"exception": str(e), "traceback": format_traceback(tb)} + + +def get_exceptions(): + try: + return list(reversed(exception_records)) + except Exception as e: + return str(e) + + def record_exception(): _, e, tb = sys.exc_info() if e is None: @@ -14,8 +29,7 @@ def record_exception(): if exception_records and exception_records[-1] == e: return - from modules import sysinfo - exception_records.append(sysinfo.format_exception(e, tb)) + exception_records.append(format_exception(e, tb)) if len(exception_records) > 5: exception_records.pop(0) @@ -93,8 +107,8 @@ def check_versions(): import torch import gradio - expected_torch_version = "2.0.0" - expected_xformers_version = "0.0.20" + expected_torch_version = "2.1.2" + expected_xformers_version = "0.0.23.post1" expected_gradio_version = "3.41.2" if version.parse(torch.__version__) < version.parse(expected_torch_version): diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py index 02a1727d2..70041ab02 100644 --- a/modules/esrgan_model.py +++ b/modules/esrgan_model.py @@ -1,121 +1,7 @@ -import sys - -import numpy as np -import torch -from PIL import Image - -import modules.esrgan_model_arch as arch -from modules import modelloader, images, devices +from modules import modelloader, devices, errors from modules.shared import opts from modules.upscaler import Upscaler, UpscalerData - - -def mod2normal(state_dict): - # this code is copied from https://github.com/victorca25/iNNfer - if 'conv_first.weight' in state_dict: - crt_net = {} - items = list(state_dict) - - crt_net['model.0.weight'] = state_dict['conv_first.weight'] - crt_net['model.0.bias'] = state_dict['conv_first.bias'] - - for k in items.copy(): - if 'RDB' in k: - ori_k = k.replace('RRDB_trunk.', 'model.1.sub.') - if '.weight' in k: - ori_k = ori_k.replace('.weight', '.0.weight') - elif '.bias' in k: - ori_k = ori_k.replace('.bias', '.0.bias') - crt_net[ori_k] = state_dict[k] - items.remove(k) - - crt_net['model.1.sub.23.weight'] = state_dict['trunk_conv.weight'] - crt_net['model.1.sub.23.bias'] = state_dict['trunk_conv.bias'] - crt_net['model.3.weight'] = state_dict['upconv1.weight'] - crt_net['model.3.bias'] = state_dict['upconv1.bias'] - crt_net['model.6.weight'] = state_dict['upconv2.weight'] - crt_net['model.6.bias'] = state_dict['upconv2.bias'] - crt_net['model.8.weight'] = state_dict['HRconv.weight'] - crt_net['model.8.bias'] = state_dict['HRconv.bias'] - crt_net['model.10.weight'] = state_dict['conv_last.weight'] - crt_net['model.10.bias'] = state_dict['conv_last.bias'] - state_dict = crt_net - return state_dict - - -def resrgan2normal(state_dict, nb=23): - # this code is copied from https://github.com/victorca25/iNNfer - if "conv_first.weight" in state_dict and "body.0.rdb1.conv1.weight" in state_dict: - re8x = 0 - crt_net = {} - items = list(state_dict) - - crt_net['model.0.weight'] = state_dict['conv_first.weight'] - crt_net['model.0.bias'] = state_dict['conv_first.bias'] - - for k in items.copy(): - if "rdb" in k: - ori_k = k.replace('body.', 'model.1.sub.') - ori_k = ori_k.replace('.rdb', '.RDB') - if '.weight' in k: - ori_k = ori_k.replace('.weight', '.0.weight') - elif '.bias' in k: - ori_k = ori_k.replace('.bias', '.0.bias') - crt_net[ori_k] = state_dict[k] - items.remove(k) - - crt_net[f'model.1.sub.{nb}.weight'] = state_dict['conv_body.weight'] - crt_net[f'model.1.sub.{nb}.bias'] = state_dict['conv_body.bias'] - crt_net['model.3.weight'] = state_dict['conv_up1.weight'] - crt_net['model.3.bias'] = state_dict['conv_up1.bias'] - crt_net['model.6.weight'] = state_dict['conv_up2.weight'] - crt_net['model.6.bias'] = state_dict['conv_up2.bias'] - - if 'conv_up3.weight' in state_dict: - # modification supporting: https://github.com/ai-forever/Real-ESRGAN/blob/main/RealESRGAN/rrdbnet_arch.py - re8x = 3 - crt_net['model.9.weight'] = state_dict['conv_up3.weight'] - crt_net['model.9.bias'] = state_dict['conv_up3.bias'] - - crt_net[f'model.{8+re8x}.weight'] = state_dict['conv_hr.weight'] - crt_net[f'model.{8+re8x}.bias'] = state_dict['conv_hr.bias'] - crt_net[f'model.{10+re8x}.weight'] = state_dict['conv_last.weight'] - crt_net[f'model.{10+re8x}.bias'] = state_dict['conv_last.bias'] - - state_dict = crt_net - return state_dict - - -def infer_params(state_dict): - # this code is copied from https://github.com/victorca25/iNNfer - scale2x = 0 - scalemin = 6 - n_uplayer = 0 - plus = False - - for block in list(state_dict): - parts = block.split(".") - n_parts = len(parts) - if n_parts == 5 and parts[2] == "sub": - nb = int(parts[3]) - elif n_parts == 3: - part_num = int(parts[1]) - if (part_num > scalemin - and parts[0] == "model" - and parts[2] == "weight"): - scale2x += 1 - if part_num > n_uplayer: - n_uplayer = part_num - out_nc = state_dict[block].shape[0] - if not plus and "conv1x1" in block: - plus = True - - nf = state_dict["model.0.weight"].shape[0] - in_nc = state_dict["model.0.weight"].shape[1] - out_nc = out_nc - scale = 2 ** scale2x - - return in_nc, out_nc, nf, nb, plus, scale +from modules.upscaler_utils import upscale_with_model class UpscalerESRGAN(Upscaler): @@ -143,12 +29,11 @@ class UpscalerESRGAN(Upscaler): def do_upscale(self, img, selected_model): try: model = self.load_model(selected_model) - except Exception as e: - print(f"Unable to load ESRGAN model {selected_model}: {e}", file=sys.stderr) + except Exception: + errors.report(f"Unable to load ESRGAN model {selected_model}", exc_info=True) return img model.to(devices.device_esrgan) - img = esrgan_upscale(model, img) - return img + return esrgan_upscale(model, img) def load_model(self, path: str): if path.startswith("http"): @@ -161,69 +46,17 @@ class UpscalerESRGAN(Upscaler): else: filename = path - state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None) - - if "params_ema" in state_dict: - state_dict = state_dict["params_ema"] - elif "params" in state_dict: - state_dict = state_dict["params"] - num_conv = 16 if "realesr-animevideov3" in filename else 32 - model = arch.SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=num_conv, upscale=4, act_type='prelu') - model.load_state_dict(state_dict) - model.eval() - return model - - if "body.0.rdb1.conv1.weight" in state_dict and "conv_first.weight" in state_dict: - nb = 6 if "RealESRGAN_x4plus_anime_6B" in filename else 23 - state_dict = resrgan2normal(state_dict, nb) - elif "conv_first.weight" in state_dict: - state_dict = mod2normal(state_dict) - elif "model.0.weight" not in state_dict: - raise Exception("The file is not a recognized ESRGAN model.") - - in_nc, out_nc, nf, nb, plus, mscale = infer_params(state_dict) - - model = arch.RRDBNet(in_nc=in_nc, out_nc=out_nc, nf=nf, nb=nb, upscale=mscale, plus=plus) - model.load_state_dict(state_dict) - model.eval() - - return model - - -def upscale_without_tiling(model, img): - img = np.array(img) - img = img[:, :, ::-1] - img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255 - img = torch.from_numpy(img).float() - img = img.unsqueeze(0).to(devices.device_esrgan) - with torch.no_grad(): - output = model(img) - output = output.squeeze().float().cpu().clamp_(0, 1).numpy() - output = 255. * np.moveaxis(output, 0, 2) - output = output.astype(np.uint8) - output = output[:, :, ::-1] - return Image.fromarray(output, 'RGB') + return modelloader.load_spandrel_model( + filename, + device=('cpu' if devices.device_esrgan.type == 'mps' else None), + expected_architecture='ESRGAN', + ) def esrgan_upscale(model, img): - if opts.ESRGAN_tile == 0: - return upscale_without_tiling(model, img) - - grid = images.split_grid(img, opts.ESRGAN_tile, opts.ESRGAN_tile, opts.ESRGAN_tile_overlap) - newtiles = [] - scale_factor = 1 - - for y, h, row in grid.tiles: - newrow = [] - for tiledata in row: - x, w, tile = tiledata - - output = upscale_without_tiling(model, tile) - scale_factor = output.width // tile.width - - newrow.append([x * scale_factor, w * scale_factor, output]) - newtiles.append([y * scale_factor, h * scale_factor, newrow]) - - newgrid = images.Grid(newtiles, grid.tile_w * scale_factor, grid.tile_h * scale_factor, grid.image_w * scale_factor, grid.image_h * scale_factor, grid.overlap * scale_factor) - output = images.combine_grid(newgrid) - return output + return upscale_with_model( + model, + img, + tile_size=opts.ESRGAN_tile, + tile_overlap=opts.ESRGAN_tile_overlap, + ) diff --git a/modules/esrgan_model_arch.py b/modules/esrgan_model_arch.py deleted file mode 100644 index 2b9888baf..000000000 --- a/modules/esrgan_model_arch.py +++ /dev/null @@ -1,465 +0,0 @@ -# this file is adapted from https://github.com/victorca25/iNNfer - -from collections import OrderedDict -import math -import torch -import torch.nn as nn -import torch.nn.functional as F - - -#################### -# RRDBNet Generator -#################### - -class RRDBNet(nn.Module): - def __init__(self, in_nc, out_nc, nf, nb, nr=3, gc=32, upscale=4, norm_type=None, - act_type='leakyrelu', mode='CNA', upsample_mode='upconv', convtype='Conv2D', - finalact=None, gaussian_noise=False, plus=False): - super(RRDBNet, self).__init__() - n_upscale = int(math.log(upscale, 2)) - if upscale == 3: - n_upscale = 1 - - self.resrgan_scale = 0 - if in_nc % 16 == 0: - self.resrgan_scale = 1 - elif in_nc != 4 and in_nc % 4 == 0: - self.resrgan_scale = 2 - - fea_conv = conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None, convtype=convtype) - rb_blocks = [RRDB(nf, nr, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero', - norm_type=norm_type, act_type=act_type, mode='CNA', convtype=convtype, - gaussian_noise=gaussian_noise, plus=plus) for _ in range(nb)] - LR_conv = conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode, convtype=convtype) - - if upsample_mode == 'upconv': - upsample_block = upconv_block - elif upsample_mode == 'pixelshuffle': - upsample_block = pixelshuffle_block - else: - raise NotImplementedError(f'upsample mode [{upsample_mode}] is not found') - if upscale == 3: - upsampler = upsample_block(nf, nf, 3, act_type=act_type, convtype=convtype) - else: - upsampler = [upsample_block(nf, nf, act_type=act_type, convtype=convtype) for _ in range(n_upscale)] - HR_conv0 = conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type, convtype=convtype) - HR_conv1 = conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None, convtype=convtype) - - outact = act(finalact) if finalact else None - - self.model = sequential(fea_conv, ShortcutBlock(sequential(*rb_blocks, LR_conv)), - *upsampler, HR_conv0, HR_conv1, outact) - - def forward(self, x, outm=None): - if self.resrgan_scale == 1: - feat = pixel_unshuffle(x, scale=4) - elif self.resrgan_scale == 2: - feat = pixel_unshuffle(x, scale=2) - else: - feat = x - - return self.model(feat) - - -class RRDB(nn.Module): - """ - Residual in Residual Dense Block - (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks) - """ - - def __init__(self, nf, nr=3, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero', - norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D', - spectral_norm=False, gaussian_noise=False, plus=False): - super(RRDB, self).__init__() - # This is for backwards compatibility with existing models - if nr == 3: - self.RDB1 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type, - norm_type, act_type, mode, convtype, spectral_norm=spectral_norm, - gaussian_noise=gaussian_noise, plus=plus) - self.RDB2 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type, - norm_type, act_type, mode, convtype, spectral_norm=spectral_norm, - gaussian_noise=gaussian_noise, plus=plus) - self.RDB3 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type, - norm_type, act_type, mode, convtype, spectral_norm=spectral_norm, - gaussian_noise=gaussian_noise, plus=plus) - else: - RDB_list = [ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type, - norm_type, act_type, mode, convtype, spectral_norm=spectral_norm, - gaussian_noise=gaussian_noise, plus=plus) for _ in range(nr)] - self.RDBs = nn.Sequential(*RDB_list) - - def forward(self, x): - if hasattr(self, 'RDB1'): - out = self.RDB1(x) - out = self.RDB2(out) - out = self.RDB3(out) - else: - out = self.RDBs(x) - return out * 0.2 + x - - -class ResidualDenseBlock_5C(nn.Module): - """ - Residual Dense Block - The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18) - Modified options that can be used: - - "Partial Convolution based Padding" arXiv:1811.11718 - - "Spectral normalization" arXiv:1802.05957 - - "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C. - {Rakotonirina} and A. {Rasoanaivo} - """ - - def __init__(self, nf=64, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero', - norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D', - spectral_norm=False, gaussian_noise=False, plus=False): - super(ResidualDenseBlock_5C, self).__init__() - - self.noise = GaussianNoise() if gaussian_noise else None - self.conv1x1 = conv1x1(nf, gc) if plus else None - - self.conv1 = conv_block(nf, gc, kernel_size, stride, bias=bias, pad_type=pad_type, - norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype, - spectral_norm=spectral_norm) - self.conv2 = conv_block(nf+gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, - norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype, - spectral_norm=spectral_norm) - self.conv3 = conv_block(nf+2*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, - norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype, - spectral_norm=spectral_norm) - self.conv4 = conv_block(nf+3*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, - norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype, - spectral_norm=spectral_norm) - if mode == 'CNA': - last_act = None - else: - last_act = act_type - self.conv5 = conv_block(nf+4*gc, nf, 3, stride, bias=bias, pad_type=pad_type, - norm_type=norm_type, act_type=last_act, mode=mode, convtype=convtype, - spectral_norm=spectral_norm) - - def forward(self, x): - x1 = self.conv1(x) - x2 = self.conv2(torch.cat((x, x1), 1)) - if self.conv1x1: - x2 = x2 + self.conv1x1(x) - x3 = self.conv3(torch.cat((x, x1, x2), 1)) - x4 = self.conv4(torch.cat((x, x1, x2, x3), 1)) - if self.conv1x1: - x4 = x4 + x2 - x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) - if self.noise: - return self.noise(x5.mul(0.2) + x) - else: - return x5 * 0.2 + x - - -#################### -# ESRGANplus -#################### - -class GaussianNoise(nn.Module): - def __init__(self, sigma=0.1, is_relative_detach=False): - super().__init__() - self.sigma = sigma - self.is_relative_detach = is_relative_detach - self.noise = torch.tensor(0, dtype=torch.float) - - def forward(self, x): - if self.training and self.sigma != 0: - self.noise = self.noise.to(x.device) - scale = self.sigma * x.detach() if self.is_relative_detach else self.sigma * x - sampled_noise = self.noise.repeat(*x.size()).normal_() * scale - x = x + sampled_noise - return x - -def conv1x1(in_planes, out_planes, stride=1): - return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) - - -#################### -# SRVGGNetCompact -#################### - -class SRVGGNetCompact(nn.Module): - """A compact VGG-style network structure for super-resolution. - This class is copied from https://github.com/xinntao/Real-ESRGAN - """ - - def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'): - super(SRVGGNetCompact, self).__init__() - self.num_in_ch = num_in_ch - self.num_out_ch = num_out_ch - self.num_feat = num_feat - self.num_conv = num_conv - self.upscale = upscale - self.act_type = act_type - - self.body = nn.ModuleList() - # the first conv - self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)) - # the first activation - if act_type == 'relu': - activation = nn.ReLU(inplace=True) - elif act_type == 'prelu': - activation = nn.PReLU(num_parameters=num_feat) - elif act_type == 'leakyrelu': - activation = nn.LeakyReLU(negative_slope=0.1, inplace=True) - self.body.append(activation) - - # the body structure - for _ in range(num_conv): - self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1)) - # activation - if act_type == 'relu': - activation = nn.ReLU(inplace=True) - elif act_type == 'prelu': - activation = nn.PReLU(num_parameters=num_feat) - elif act_type == 'leakyrelu': - activation = nn.LeakyReLU(negative_slope=0.1, inplace=True) - self.body.append(activation) - - # the last conv - self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1)) - # upsample - self.upsampler = nn.PixelShuffle(upscale) - - def forward(self, x): - out = x - for i in range(0, len(self.body)): - out = self.body[i](out) - - out = self.upsampler(out) - # add the nearest upsampled image, so that the network learns the residual - base = F.interpolate(x, scale_factor=self.upscale, mode='nearest') - out += base - return out - - -#################### -# Upsampler -#################### - -class Upsample(nn.Module): - r"""Upsamples a given multi-channel 1D (temporal), 2D (spatial) or 3D (volumetric) data. - The input data is assumed to be of the form - `minibatch x channels x [optional depth] x [optional height] x width`. - """ - - def __init__(self, size=None, scale_factor=None, mode="nearest", align_corners=None): - super(Upsample, self).__init__() - if isinstance(scale_factor, tuple): - self.scale_factor = tuple(float(factor) for factor in scale_factor) - else: - self.scale_factor = float(scale_factor) if scale_factor else None - self.mode = mode - self.size = size - self.align_corners = align_corners - - def forward(self, x): - return nn.functional.interpolate(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners) - - def extra_repr(self): - if self.scale_factor is not None: - info = f'scale_factor={self.scale_factor}' - else: - info = f'size={self.size}' - info += f', mode={self.mode}' - return info - - -def pixel_unshuffle(x, scale): - """ Pixel unshuffle. - Args: - x (Tensor): Input feature with shape (b, c, hh, hw). - scale (int): Downsample ratio. - Returns: - Tensor: the pixel unshuffled feature. - """ - b, c, hh, hw = x.size() - out_channel = c * (scale**2) - assert hh % scale == 0 and hw % scale == 0 - h = hh // scale - w = hw // scale - x_view = x.view(b, c, h, scale, w, scale) - return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w) - - -def pixelshuffle_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True, - pad_type='zero', norm_type=None, act_type='relu', convtype='Conv2D'): - """ - Pixel shuffle layer - (Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional - Neural Network, CVPR17) - """ - conv = conv_block(in_nc, out_nc * (upscale_factor ** 2), kernel_size, stride, bias=bias, - pad_type=pad_type, norm_type=None, act_type=None, convtype=convtype) - pixel_shuffle = nn.PixelShuffle(upscale_factor) - - n = norm(norm_type, out_nc) if norm_type else None - a = act(act_type) if act_type else None - return sequential(conv, pixel_shuffle, n, a) - - -def upconv_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True, - pad_type='zero', norm_type=None, act_type='relu', mode='nearest', convtype='Conv2D'): - """ Upconv layer """ - upscale_factor = (1, upscale_factor, upscale_factor) if convtype == 'Conv3D' else upscale_factor - upsample = Upsample(scale_factor=upscale_factor, mode=mode) - conv = conv_block(in_nc, out_nc, kernel_size, stride, bias=bias, - pad_type=pad_type, norm_type=norm_type, act_type=act_type, convtype=convtype) - return sequential(upsample, conv) - - - - - - - - -#################### -# Basic blocks -#################### - - -def make_layer(basic_block, num_basic_block, **kwarg): - """Make layers by stacking the same blocks. - Args: - basic_block (nn.module): nn.module class for basic block. (block) - num_basic_block (int): number of blocks. (n_layers) - Returns: - nn.Sequential: Stacked blocks in nn.Sequential. - """ - layers = [] - for _ in range(num_basic_block): - layers.append(basic_block(**kwarg)) - return nn.Sequential(*layers) - - -def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1, beta=1.0): - """ activation helper """ - act_type = act_type.lower() - if act_type == 'relu': - layer = nn.ReLU(inplace) - elif act_type in ('leakyrelu', 'lrelu'): - layer = nn.LeakyReLU(neg_slope, inplace) - elif act_type == 'prelu': - layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope) - elif act_type == 'tanh': # [-1, 1] range output - layer = nn.Tanh() - elif act_type == 'sigmoid': # [0, 1] range output - layer = nn.Sigmoid() - else: - raise NotImplementedError(f'activation layer [{act_type}] is not found') - return layer - - -class Identity(nn.Module): - def __init__(self, *kwargs): - super(Identity, self).__init__() - - def forward(self, x, *kwargs): - return x - - -def norm(norm_type, nc): - """ Return a normalization layer """ - norm_type = norm_type.lower() - if norm_type == 'batch': - layer = nn.BatchNorm2d(nc, affine=True) - elif norm_type == 'instance': - layer = nn.InstanceNorm2d(nc, affine=False) - elif norm_type == 'none': - def norm_layer(x): return Identity() - else: - raise NotImplementedError(f'normalization layer [{norm_type}] is not found') - return layer - - -def pad(pad_type, padding): - """ padding layer helper """ - pad_type = pad_type.lower() - if padding == 0: - return None - if pad_type == 'reflect': - layer = nn.ReflectionPad2d(padding) - elif pad_type == 'replicate': - layer = nn.ReplicationPad2d(padding) - elif pad_type == 'zero': - layer = nn.ZeroPad2d(padding) - else: - raise NotImplementedError(f'padding layer [{pad_type}] is not implemented') - return layer - - -def get_valid_padding(kernel_size, dilation): - kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1) - padding = (kernel_size - 1) // 2 - return padding - - -class ShortcutBlock(nn.Module): - """ Elementwise sum the output of a submodule to its input """ - def __init__(self, submodule): - super(ShortcutBlock, self).__init__() - self.sub = submodule - - def forward(self, x): - output = x + self.sub(x) - return output - - def __repr__(self): - return 'Identity + \n|' + self.sub.__repr__().replace('\n', '\n|') - - -def sequential(*args): - """ Flatten Sequential. It unwraps nn.Sequential. """ - if len(args) == 1: - if isinstance(args[0], OrderedDict): - raise NotImplementedError('sequential does not support OrderedDict input.') - return args[0] # No sequential is needed. - modules = [] - for module in args: - if isinstance(module, nn.Sequential): - for submodule in module.children(): - modules.append(submodule) - elif isinstance(module, nn.Module): - modules.append(module) - return nn.Sequential(*modules) - - -def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=True, - pad_type='zero', norm_type=None, act_type='relu', mode='CNA', convtype='Conv2D', - spectral_norm=False): - """ Conv layer with padding, normalization, activation """ - assert mode in ['CNA', 'NAC', 'CNAC'], f'Wrong conv mode [{mode}]' - padding = get_valid_padding(kernel_size, dilation) - p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None - padding = padding if pad_type == 'zero' else 0 - - if convtype=='PartialConv2D': - from torchvision.ops import PartialConv2d # this is definitely not going to work, but PartialConv2d doesn't work anyway and this shuts up static analyzer - c = PartialConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding, - dilation=dilation, bias=bias, groups=groups) - elif convtype=='DeformConv2D': - from torchvision.ops import DeformConv2d # not tested - c = DeformConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding, - dilation=dilation, bias=bias, groups=groups) - elif convtype=='Conv3D': - c = nn.Conv3d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding, - dilation=dilation, bias=bias, groups=groups) - else: - c = nn.Conv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding, - dilation=dilation, bias=bias, groups=groups) - - if spectral_norm: - c = nn.utils.spectral_norm(c) - - a = act(act_type) if act_type else None - if 'CNA' in mode: - n = norm(norm_type, out_nc) if norm_type else None - return sequential(p, c, n, a) - elif mode == 'NAC': - if norm_type is None and act_type is not None: - a = act(act_type, inplace=False) - n = norm(norm_type, in_nc) if norm_type else None - return sequential(n, a, p, c) diff --git a/modules/extensions.py b/modules/extensions.py index bf9a1878f..1899cd529 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -1,11 +1,14 @@ +from __future__ import annotations + +import configparser import os import threading +import re from modules import shared, errors, cache, scripts from modules.gitpython_hack import Repo from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401 -extensions = [] os.makedirs(extensions_dir, exist_ok=True) @@ -19,11 +22,55 @@ def active(): return [x for x in extensions if x.enabled] +class ExtensionMetadata: + filename = "metadata.ini" + config: configparser.ConfigParser + canonical_name: str + requires: list + + def __init__(self, path, canonical_name): + self.config = configparser.ConfigParser() + + filepath = os.path.join(path, self.filename) + if os.path.isfile(filepath): + try: + self.config.read(filepath) + except Exception: + errors.report(f"Error reading {self.filename} for extension {canonical_name}.", exc_info=True) + + self.canonical_name = self.config.get("Extension", "Name", fallback=canonical_name) + self.canonical_name = canonical_name.lower().strip() + + self.requires = self.get_script_requirements("Requires", "Extension") + + def get_script_requirements(self, field, section, extra_section=None): + """reads a list of requirements from the config; field is the name of the field in the ini file, + like Requires or Before, and section is the name of the [section] in the ini file; additionally, + reads more requirements from [extra_section] if specified.""" + + x = self.config.get(section, field, fallback='') + + if extra_section: + x = x + ', ' + self.config.get(extra_section, field, fallback='') + + return self.parse_list(x.lower()) + + def parse_list(self, text): + """converts a line from config ("ext1 ext2, ext3 ") into a python list (["ext1", "ext2", "ext3"])""" + + if not text: + return [] + + # both "," and " " are accepted as separator + return [x for x in re.split(r"[,\s]+", text.strip()) if x] + + class Extension: lock = threading.Lock() cached_fields = ['remote', 'commit_date', 'branch', 'commit_hash', 'version'] + metadata: ExtensionMetadata - def __init__(self, name, path, enabled=True, is_builtin=False): + def __init__(self, name, path, enabled=True, is_builtin=False, metadata=None): self.name = name self.path = path self.enabled = enabled @@ -36,6 +83,8 @@ class Extension: self.branch = None self.remote = None self.have_info_from_repo = False + self.metadata = metadata if metadata else ExtensionMetadata(self.path, name.lower()) + self.canonical_name = metadata.canonical_name def to_dict(self): return {x: getattr(self, x) for x in self.cached_fields} @@ -56,6 +105,7 @@ class Extension: self.do_read_info_from_repo() return self.to_dict() + try: d = cache.cached_data_for_file('extensions-git', self.name, os.path.join(self.path, ".git"), read_from_repo) self.from_dict(d) @@ -136,9 +186,6 @@ class Extension: def list_extensions(): extensions.clear() - if not os.path.isdir(extensions_dir): - return - if shared.cmd_opts.disable_all_extensions: print("*** \"--disable-all-extensions\" arg was used, will not load any extensions ***") elif shared.opts.disable_all_extensions == "all": @@ -148,18 +195,43 @@ def list_extensions(): elif shared.opts.disable_all_extensions == "extra": print("*** \"Disable all extensions\" option was set, will only load built-in extensions ***") - extension_paths = [] - for dirname in [extensions_dir, extensions_builtin_dir]: + loaded_extensions = {} + + # scan through extensions directory and load metadata + for dirname in [extensions_builtin_dir, extensions_dir]: if not os.path.isdir(dirname): - return + continue for extension_dirname in sorted(os.listdir(dirname)): path = os.path.join(dirname, extension_dirname) if not os.path.isdir(path): continue - extension_paths.append((extension_dirname, path, dirname == extensions_builtin_dir)) + canonical_name = extension_dirname + metadata = ExtensionMetadata(path, canonical_name) - for dirname, path, is_builtin in extension_paths: - extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin) - extensions.append(extension) + # check for duplicated canonical names + already_loaded_extension = loaded_extensions.get(metadata.canonical_name) + if already_loaded_extension is not None: + errors.report(f'Duplicate canonical name "{canonical_name}" found in extensions "{extension_dirname}" and "{already_loaded_extension.name}". Former will be discarded.', exc_info=False) + continue + + is_builtin = dirname == extensions_builtin_dir + extension = Extension(name=extension_dirname, path=path, enabled=extension_dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin, metadata=metadata) + extensions.append(extension) + loaded_extensions[canonical_name] = extension + + # check for requirements + for extension in extensions: + for req in extension.metadata.requires: + required_extension = loaded_extensions.get(req) + if required_extension is None: + errors.report(f'Extension "{extension.name}" requires "{req}" which is not installed.', exc_info=False) + continue + + if not extension.enabled: + errors.report(f'Extension "{extension.name}" requires "{required_extension.name}" which is disabled.', exc_info=False) + continue + + +extensions: list[Extension] = [] diff --git a/modules/face_restoration_utils.py b/modules/face_restoration_utils.py new file mode 100644 index 000000000..1cbac2364 --- /dev/null +++ b/modules/face_restoration_utils.py @@ -0,0 +1,180 @@ +from __future__ import annotations + +import logging +import os +from functools import cached_property +from typing import TYPE_CHECKING, Callable + +import cv2 +import numpy as np +import torch + +from modules import devices, errors, face_restoration, shared + +if TYPE_CHECKING: + from facexlib.utils.face_restoration_helper import FaceRestoreHelper + +logger = logging.getLogger(__name__) + + +def bgr_image_to_rgb_tensor(img: np.ndarray) -> torch.Tensor: + """Convert a BGR NumPy image in [0..1] range to a PyTorch RGB float32 tensor.""" + assert img.shape[2] == 3, "image must be RGB" + if img.dtype == "float64": + img = img.astype("float32") + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + return torch.from_numpy(img.transpose(2, 0, 1)).float() + + +def rgb_tensor_to_bgr_image(tensor: torch.Tensor, *, min_max=(0.0, 1.0)) -> np.ndarray: + """ + Convert a PyTorch RGB tensor in range `min_max` to a BGR NumPy image in [0..1] range. + """ + tensor = tensor.squeeze(0).float().detach().cpu().clamp_(*min_max) + tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) + assert tensor.dim() == 3, "tensor must be RGB" + img_np = tensor.numpy().transpose(1, 2, 0) + if img_np.shape[2] == 1: # gray image, no RGB/BGR required + return np.squeeze(img_np, axis=2) + return cv2.cvtColor(img_np, cv2.COLOR_BGR2RGB) + + +def create_face_helper(device) -> FaceRestoreHelper: + from facexlib.detection import retinaface + from facexlib.utils.face_restoration_helper import FaceRestoreHelper + if hasattr(retinaface, 'device'): + retinaface.device = device + return FaceRestoreHelper( + upscale_factor=1, + face_size=512, + crop_ratio=(1, 1), + det_model='retinaface_resnet50', + save_ext='png', + use_parse=True, + device=device, + ) + + +def restore_with_face_helper( + np_image: np.ndarray, + face_helper: FaceRestoreHelper, + restore_face: Callable[[torch.Tensor], torch.Tensor], +) -> np.ndarray: + """ + Find faces in the image using face_helper, restore them using restore_face, and paste them back into the image. + + `restore_face` should take a cropped face image and return a restored face image. + """ + from torchvision.transforms.functional import normalize + np_image = np_image[:, :, ::-1] + original_resolution = np_image.shape[0:2] + + try: + logger.debug("Detecting faces...") + face_helper.clean_all() + face_helper.read_image(np_image) + face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5) + face_helper.align_warp_face() + logger.debug("Found %d faces, restoring", len(face_helper.cropped_faces)) + for cropped_face in face_helper.cropped_faces: + cropped_face_t = bgr_image_to_rgb_tensor(cropped_face / 255.0) + normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) + cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer) + + try: + with torch.no_grad(): + cropped_face_t = restore_face(cropped_face_t) + devices.torch_gc() + except Exception: + errors.report('Failed face-restoration inference', exc_info=True) + + restored_face = rgb_tensor_to_bgr_image(cropped_face_t, min_max=(-1, 1)) + restored_face = (restored_face * 255.0).astype('uint8') + face_helper.add_restored_face(restored_face) + + logger.debug("Merging restored faces into image") + face_helper.get_inverse_affine(None) + img = face_helper.paste_faces_to_input_image() + img = img[:, :, ::-1] + if original_resolution != img.shape[0:2]: + img = cv2.resize( + img, + (0, 0), + fx=original_resolution[1] / img.shape[1], + fy=original_resolution[0] / img.shape[0], + interpolation=cv2.INTER_LINEAR, + ) + logger.debug("Face restoration complete") + finally: + face_helper.clean_all() + return img + + +class CommonFaceRestoration(face_restoration.FaceRestoration): + net: torch.Module | None + model_url: str + model_download_name: str + + def __init__(self, model_path: str): + super().__init__() + self.net = None + self.model_path = model_path + os.makedirs(model_path, exist_ok=True) + + @cached_property + def face_helper(self) -> FaceRestoreHelper: + return create_face_helper(self.get_device()) + + def send_model_to(self, device): + if self.net: + logger.debug("Sending %s to %s", self.net, device) + self.net.to(device) + if self.face_helper: + logger.debug("Sending face helper to %s", device) + self.face_helper.face_det.to(device) + self.face_helper.face_parse.to(device) + + def get_device(self): + raise NotImplementedError("get_device must be implemented by subclasses") + + def load_net(self) -> torch.Module: + raise NotImplementedError("load_net must be implemented by subclasses") + + def restore_with_helper( + self, + np_image: np.ndarray, + restore_face: Callable[[torch.Tensor], torch.Tensor], + ) -> np.ndarray: + try: + if self.net is None: + self.net = self.load_net() + except Exception: + logger.warning("Unable to load face-restoration model", exc_info=True) + return np_image + + try: + self.send_model_to(self.get_device()) + return restore_with_face_helper(np_image, self.face_helper, restore_face) + finally: + if shared.opts.face_restoration_unload: + self.send_model_to(devices.cpu) + + +def patch_facexlib(dirname: str) -> None: + import facexlib.detection + import facexlib.parsing + + det_facex_load_file_from_url = facexlib.detection.load_file_from_url + par_facex_load_file_from_url = facexlib.parsing.load_file_from_url + + def update_kwargs(kwargs): + return dict(kwargs, save_dir=dirname, model_dir=None) + + def facex_load_file_from_url(**kwargs): + return det_facex_load_file_from_url(**update_kwargs(kwargs)) + + def facex_load_file_from_url2(**kwargs): + return par_facex_load_file_from_url(**update_kwargs(kwargs)) + + facexlib.detection.load_file_from_url = facex_load_file_from_url + facexlib.parsing.load_file_from_url = facex_load_file_from_url2 diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py index 8e0f13bdc..445b04092 100644 --- a/modules/gfpgan_model.py +++ b/modules/gfpgan_model.py @@ -1,110 +1,71 @@ +from __future__ import annotations + +import logging import os -import facexlib -import gfpgan +import torch -import modules.face_restoration -from modules import paths, shared, devices, modelloader, errors +from modules import ( + devices, + errors, + face_restoration, + face_restoration_utils, + modelloader, + shared, +) -model_dir = "GFPGAN" -user_path = None -model_path = os.path.join(paths.models_path, model_dir) +logger = logging.getLogger(__name__) model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth" -have_gfpgan = False -loaded_gfpgan_model = None +model_download_name = "GFPGANv1.4.pth" +gfpgan_face_restorer: face_restoration.FaceRestoration | None = None -def gfpgann(): - global loaded_gfpgan_model - global model_path - if loaded_gfpgan_model is not None: - loaded_gfpgan_model.gfpgan.to(devices.device_gfpgan) - return loaded_gfpgan_model +class FaceRestorerGFPGAN(face_restoration_utils.CommonFaceRestoration): + def name(self): + return "GFPGAN" - if gfpgan_constructor is None: - return None + def get_device(self): + return devices.device_gfpgan - models = modelloader.load_models(model_path, model_url, user_path, ext_filter="GFPGAN") - if len(models) == 1 and models[0].startswith("http"): - model_file = models[0] - elif len(models) != 0: - latest_file = max(models, key=os.path.getctime) - model_file = latest_file - else: - print("Unable to load gfpgan model!") - return None - if hasattr(facexlib.detection.retinaface, 'device'): - facexlib.detection.retinaface.device = devices.device_gfpgan - model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=devices.device_gfpgan) - loaded_gfpgan_model = model + def load_net(self) -> torch.Module: + for model_path in modelloader.load_models( + model_path=self.model_path, + model_url=model_url, + command_path=self.model_path, + download_name=model_download_name, + ext_filter=['.pth'], + ): + if 'GFPGAN' in os.path.basename(model_path): + model = modelloader.load_spandrel_model( + model_path, + device=self.get_device(), + expected_architecture='GFPGAN', + ).model + model.different_w = True # see https://github.com/chaiNNer-org/spandrel/pull/81 + return model + raise ValueError("No GFPGAN model found") - return model + def restore(self, np_image): + def restore_face(cropped_face_t): + assert self.net is not None + return self.net(cropped_face_t, return_rgb=False)[0] - -def send_model_to(model, device): - model.gfpgan.to(device) - model.face_helper.face_det.to(device) - model.face_helper.face_parse.to(device) + return self.restore_with_helper(np_image, restore_face) def gfpgan_fix_faces(np_image): - model = gfpgann() - if model is None: - return np_image - - send_model_to(model, devices.device_gfpgan) - - np_image_bgr = np_image[:, :, ::-1] - cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True) - np_image = gfpgan_output_bgr[:, :, ::-1] - - model.face_helper.clean_all() - - if shared.opts.face_restoration_unload: - send_model_to(model, devices.cpu) - + if gfpgan_face_restorer: + return gfpgan_face_restorer.restore(np_image) + logger.warning("GFPGAN face restorer not set up") return np_image -gfpgan_constructor = None +def setup_model(dirname: str) -> None: + global gfpgan_face_restorer - -def setup_model(dirname): try: - os.makedirs(model_path, exist_ok=True) - from gfpgan import GFPGANer - from facexlib import detection, parsing # noqa: F401 - global user_path - global have_gfpgan - global gfpgan_constructor - - load_file_from_url_orig = gfpgan.utils.load_file_from_url - facex_load_file_from_url_orig = facexlib.detection.load_file_from_url - facex_load_file_from_url_orig2 = facexlib.parsing.load_file_from_url - - def my_load_file_from_url(**kwargs): - return load_file_from_url_orig(**dict(kwargs, model_dir=model_path)) - - def facex_load_file_from_url(**kwargs): - return facex_load_file_from_url_orig(**dict(kwargs, save_dir=model_path, model_dir=None)) - - def facex_load_file_from_url2(**kwargs): - return facex_load_file_from_url_orig2(**dict(kwargs, save_dir=model_path, model_dir=None)) - - gfpgan.utils.load_file_from_url = my_load_file_from_url - facexlib.detection.load_file_from_url = facex_load_file_from_url - facexlib.parsing.load_file_from_url = facex_load_file_from_url2 - user_path = dirname - have_gfpgan = True - gfpgan_constructor = GFPGANer - - class FaceRestorerGFPGAN(modules.face_restoration.FaceRestoration): - def name(self): - return "GFPGAN" - - def restore(self, np_image): - return gfpgan_fix_faces(np_image) - - shared.face_restorers.append(FaceRestorerGFPGAN()) + face_restoration_utils.patch_facexlib(dirname) + gfpgan_face_restorer = FaceRestorerGFPGAN(model_path=dirname) + shared.face_restorers.append(gfpgan_face_restorer) except Exception: errors.report("Error setting up GFPGAN", exc_info=True) diff --git a/modules/gradio_extensons.py b/modules/gradio_extensons.py index e6b6835ad..7d88dc984 100644 --- a/modules/gradio_extensons.py +++ b/modules/gradio_extensons.py @@ -47,10 +47,20 @@ def Block_get_config(self): def BlockContext_init(self, *args, **kwargs): + if scripts.scripts_current is not None: + scripts.scripts_current.before_component(self, **kwargs) + + scripts.script_callbacks.before_component_callback(self, **kwargs) + res = original_BlockContext_init(self, *args, **kwargs) add_classes_to_gradio_component(self) + scripts.script_callbacks.after_component_callback(self, **kwargs) + + if scripts.scripts_current is not None: + scripts.scripts_current.after_component(self, **kwargs) + return res diff --git a/modules/hat_model.py b/modules/hat_model.py new file mode 100644 index 000000000..7f2abb416 --- /dev/null +++ b/modules/hat_model.py @@ -0,0 +1,43 @@ +import os +import sys + +from modules import modelloader, devices +from modules.shared import opts +from modules.upscaler import Upscaler, UpscalerData +from modules.upscaler_utils import upscale_with_model + + +class UpscalerHAT(Upscaler): + def __init__(self, dirname): + self.name = "HAT" + self.scalers = [] + self.user_path = dirname + super().__init__() + for file in self.find_models(ext_filter=[".pt", ".pth"]): + name = modelloader.friendly_name(file) + scale = 4 # TODO: scale might not be 4, but we can't know without loading the model + scaler_data = UpscalerData(name, file, upscaler=self, scale=scale) + self.scalers.append(scaler_data) + + def do_upscale(self, img, selected_model): + try: + model = self.load_model(selected_model) + except Exception as e: + print(f"Unable to load HAT model {selected_model}: {e}", file=sys.stderr) + return img + model.to(devices.device_esrgan) # TODO: should probably be device_hat + return upscale_with_model( + model, + img, + tile_size=opts.ESRGAN_tile, # TODO: should probably be HAT_tile + tile_overlap=opts.ESRGAN_tile_overlap, # TODO: should probably be HAT_tile_overlap + ) + + def load_model(self, path: str): + if not os.path.isfile(path): + raise FileNotFoundError(f"Model file {path} not found") + return modelloader.load_spandrel_model( + path, + device=devices.device_esrgan, # TODO: should probably be device_hat + expected_architecture='HAT', + ) diff --git a/modules/images.py b/modules/images.py index daf4eebe4..87a7bf221 100644 --- a/modules/images.py +++ b/modules/images.py @@ -61,12 +61,17 @@ def image_grid(imgs, batch_size=1, rows=None): return grid -Grid = namedtuple("Grid", ["tiles", "tile_w", "tile_h", "image_w", "image_h", "overlap"]) +class Grid(namedtuple("_Grid", ["tiles", "tile_w", "tile_h", "image_w", "image_h", "overlap"])): + @property + def tile_count(self) -> int: + """ + The total number of tiles in the grid. + """ + return sum(len(row[2]) for row in self.tiles) -def split_grid(image, tile_w=512, tile_h=512, overlap=64): - w = image.width - h = image.height +def split_grid(image: Image.Image, tile_w: int = 512, tile_h: int = 512, overlap: int = 64) -> Grid: + w, h = image.size non_overlap_width = tile_w - overlap non_overlap_height = tile_h - overlap @@ -791,3 +796,4 @@ def flatten(img, bgcolor): img = background return img.convert('RGB') + diff --git a/modules/img2img.py b/modules/img2img.py index 52cb577a6..e7e8e2510 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -7,7 +7,7 @@ from PIL import Image, ImageOps, ImageFilter, ImageEnhance, UnidentifiedImageErr import gradio as gr from modules import images as imgutil -from modules.generation_parameters_copypaste import create_override_settings_dict, parse_generation_parameters +from modules.infotext import create_override_settings_dict, parse_generation_parameters from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images from modules.shared import opts, state from modules.sd_models import get_closet_checkpoint_match @@ -44,12 +44,14 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal steps = p.steps override_settings = p.override_settings sd_model_checkpoint_override = get_closet_checkpoint_match(override_settings.get("sd_model_checkpoint", None)) + batch_results = None + discard_further_results = False for i, image in enumerate(images): state.job = f"{i+1} out of {len(images)}" if state.skipped: state.skipped = False - if state.interrupted: + if state.interrupted or state.stopping_generation: break try: @@ -127,7 +129,21 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal if proc is None: p.override_settings.pop('save_images_replace_action', None) - process_images(p) + proc = process_images(p) + + if not discard_further_results and proc: + if batch_results: + batch_results.images.extend(proc.images) + batch_results.infotexts.extend(proc.infotexts) + else: + batch_results = proc + + if 0 <= shared.opts.img2img_batch_show_results_limit < len(batch_results.images): + discard_further_results = True + batch_results.images = batch_results.images[:int(shared.opts.img2img_batch_show_results_limit)] + batch_results.infotexts = batch_results.infotexts[:int(shared.opts.img2img_batch_show_results_limit)] + + return batch_results def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_name: str, mask_blur: int, mask_alpha: float, inpainting_fill: int, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args): @@ -212,10 +228,10 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s with closing(p): if is_batch: assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled" + processed = process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args, to_scale=selected_scale_tab == 1, scale_by=scale_by, use_png_info=img2img_batch_use_png_info, png_info_props=img2img_batch_png_info_props, png_info_dir=img2img_batch_png_info_dir) - process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args, to_scale=selected_scale_tab == 1, scale_by=scale_by, use_png_info=img2img_batch_use_png_info, png_info_props=img2img_batch_png_info_props, png_info_dir=img2img_batch_png_info_dir) - - processed = Processed(p, [], p.seed, "") + if processed is None: + processed = Processed(p, [], p.seed, "") else: processed = modules.scripts.scripts_img2img.run(p, *args) if processed is None: diff --git a/modules/import_hook.py b/modules/import_hook.py index 28c67dfa8..eba9a3729 100644 --- a/modules/import_hook.py +++ b/modules/import_hook.py @@ -3,3 +3,14 @@ import sys # this will break any attempt to import xformers which will prevent stability diffusion repo from trying to use it if "--xformers" not in "".join(sys.argv): sys.modules["xformers"] = None + +# Hack to fix a changed import in torchvision 0.17+, which otherwise breaks +# basicsr; see https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/13985 +try: + import torchvision.transforms.functional_tensor # noqa: F401 +except ImportError: + try: + import torchvision.transforms.functional as functional + sys.modules["torchvision.transforms.functional_tensor"] = functional + except ImportError: + pass # shrug... diff --git a/modules/generation_parameters_copypaste.py b/modules/infotext.py similarity index 81% rename from modules/generation_parameters_copypaste.py rename to modules/infotext.py index 0a606515b..26e9b9493 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/infotext.py @@ -1,23 +1,24 @@ +from __future__ import annotations import base64 import io import json import os import re +import sys import gradio as gr from modules.paths import data_path -from modules import shared, ui_tempdir, script_callbacks, processing +from modules import shared, ui_tempdir, script_callbacks, processing, infotext_versions from PIL import Image +sys.modules['modules.generation_parameters_copypaste'] = sys.modules[__name__] # alias for old name + re_param_code = r'\s*(\w[\w \-/]+):\s*("(?:\\.|[^\\"])+"|[^,]*)(?:,|$)' re_param = re.compile(re_param_code) re_imagesize = re.compile(r"^(\d+)x(\d+)$") re_hypernet_hash = re.compile("\(([0-9a-f]+)\)$") type_of_gr_update = type(gr.update()) -paste_fields = {} -registered_param_bindings = [] - class ParamBinding: def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None, paste_field_names=None): @@ -30,6 +31,23 @@ class ParamBinding: self.paste_field_names = paste_field_names or [] +class PasteField(tuple): + def __new__(cls, component, target, *, api=None): + return super().__new__(cls, (component, target)) + + def __init__(self, component, target, *, api=None): + super().__init__() + + self.api = api + self.component = component + self.label = target if isinstance(target, str) else None + self.function = target if callable(target) else None + + +paste_fields: dict[str, dict] = {} +registered_param_bindings: list[ParamBinding] = [] + + def reset(): paste_fields.clear() registered_param_bindings.clear() @@ -82,6 +100,12 @@ def image_from_url_text(filedata): def add_paste_fields(tabname, init_img, fields, override_settings_component=None): + + if fields: + for i in range(len(fields)): + if not isinstance(fields[i], PasteField): + fields[i] = PasteField(*fields[i]) + paste_fields[tabname] = {"init_img": init_img, "fields": fields, "override_settings_component": override_settings_component} # backwards compatibility for existing extensions @@ -113,7 +137,6 @@ def register_paste_params_button(binding: ParamBinding): def connect_paste_params_buttons(): - binding: ParamBinding for binding in registered_param_bindings: destination_image_component = paste_fields[binding.tabname]["init_img"] fields = paste_fields[binding.tabname]["fields"] @@ -313,6 +336,17 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model if "VAE Decoder" not in res: res["VAE Decoder"] = "Full" + if "FP8 weight" not in res: + res["FP8 weight"] = "Disable" + + if "Cache FP16 weight for LoRA" not in res and res["FP8 weight"] != "Disable": + res["Cache FP16 weight for LoRA"] = False + + infotext_versions.backcompat(res) + + skip = set(shared.opts.infotext_skip_pasting) + res = {k: v for k, v in res.items() if k not in skip} + return res @@ -361,6 +395,48 @@ def create_override_settings_dict(text_pairs): return res +def get_override_settings(params, *, skip_fields=None): + """Returns a list of settings overrides from the infotext parameters dictionary. + + This function checks the `params` dictionary for any keys that correspond to settings in `shared.opts` and returns + a list of tuples containing the parameter name, setting name, and new value cast to correct type. + + It checks for conditions before adding an override: + - ignores settings that match the current value + - ignores parameter keys present in skip_fields argument. + + Example input: + {"Clip skip": "2"} + + Example output: + [("Clip skip", "CLIP_stop_at_last_layers", 2)] + """ + + res = [] + + mapping = [(info.infotext, k) for k, info in shared.opts.data_labels.items() if info.infotext] + for param_name, setting_name in mapping + infotext_to_setting_name_mapping: + if param_name in (skip_fields or {}): + continue + + v = params.get(param_name, None) + if v is None: + continue + + if setting_name == "sd_model_checkpoint" and shared.opts.disable_weights_auto_swap: + continue + + v = shared.opts.cast_value(setting_name, v) + current_value = getattr(shared.opts, setting_name, None) + + if v == current_value: + continue + + res.append((param_name, setting_name, v)) + + return res + + def connect_paste(button, paste_fields, input_comp, override_settings_component, tabname): def paste_func(prompt): if not prompt and not shared.cmd_opts.hide_ui_dir_config: @@ -402,29 +478,9 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component, already_handled_fields = {key: 1 for _, key in paste_fields} def paste_settings(params): - vals = {} + vals = get_override_settings(params, skip_fields=already_handled_fields) - mapping = [(info.infotext, k) for k, info in shared.opts.data_labels.items() if info.infotext] - for param_name, setting_name in mapping + infotext_to_setting_name_mapping: - if param_name in already_handled_fields: - continue - - v = params.get(param_name, None) - if v is None: - continue - - if setting_name == "sd_model_checkpoint" and shared.opts.disable_weights_auto_swap: - continue - - v = shared.opts.cast_value(setting_name, v) - current_value = getattr(shared.opts, setting_name, None) - - if v == current_value: - continue - - vals[param_name] = v - - vals_pairs = [f"{k}: {v}" for k, v in vals.items()] + vals_pairs = [f"{infotext_text}: {value}" for infotext_text, setting_name, value in vals] return gr.Dropdown.update(value=vals_pairs, choices=vals_pairs, visible=bool(vals_pairs)) @@ -443,3 +499,4 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component, outputs=[], show_progress=False, ) + diff --git a/modules/infotext_versions.py b/modules/infotext_versions.py new file mode 100644 index 000000000..a5afeebf1 --- /dev/null +++ b/modules/infotext_versions.py @@ -0,0 +1,39 @@ +from modules import shared +from packaging import version +import re + + +v160 = version.parse("1.6.0") +v170_tsnr = version.parse("v1.7.0-225") + + +def parse_version(text): + if text is None: + return None + + m = re.match(r'([^-]+-[^-]+)-.*', text) + if m: + text = m.group(1) + + try: + return version.parse(text) + except Exception: + return None + + +def backcompat(d): + """Checks infotext Version field, and enables backwards compatibility options according to it.""" + + if not shared.opts.auto_backcompat: + return + + ver = parse_version(d.get("Version")) + if ver is None: + return + + if ver < v160: + d["Old prompt editing timelines"] = True + + if ver < v170_tsnr: + d["Downcast alphas_cumprod"] = True + diff --git a/modules/initialize.py b/modules/initialize.py index ac95fc6f0..4a3cd98cf 100644 --- a/modules/initialize.py +++ b/modules/initialize.py @@ -54,9 +54,6 @@ def initialize(): initialize_util.configure_sigint_handler() initialize_util.configure_opts_onchange() - from modules import modelloader - modelloader.cleanup_models() - from modules import sd_models sd_models.setup_model() startup_timer.record("setup SD model") diff --git a/modules/initialize_util.py b/modules/initialize_util.py index 2e9b6d895..b6767138d 100644 --- a/modules/initialize_util.py +++ b/modules/initialize_util.py @@ -177,6 +177,8 @@ def configure_opts_onchange(): shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed) shared.opts.onchange("gradio_theme", shared.reload_gradio_theme) shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False) + shared.opts.onchange("fp8_storage", wrap_queued_call(lambda: sd_models.reload_model_weights()), call=False) + shared.opts.onchange("cache_fp16_weight", wrap_queued_call(lambda: sd_models.reload_model_weights(forced_reload=True)), call=False) startup_timer.record("opts onchange") diff --git a/modules/interrogate.py b/modules/interrogate.py index 3045560d0..35a627ca1 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -10,7 +10,7 @@ import torch.hub from torchvision import transforms from torchvision.transforms.functional import InterpolationMode -from modules import devices, paths, shared, lowvram, modelloader, errors +from modules import devices, paths, shared, lowvram, modelloader, errors, torch_utils blip_image_eval_size = 384 clip_model_name = 'ViT-L/14' @@ -131,7 +131,7 @@ class InterrogateModels: self.clip_model = self.clip_model.to(devices.device_interrogate) - self.dtype = next(self.clip_model.parameters()).dtype + self.dtype = torch_utils.get_param(self.clip_model).dtype def send_clip_to_ram(self): if not shared.opts.interrogate_keep_models_in_memory: diff --git a/modules/launch_utils.py b/modules/launch_utils.py index 8cdbafa50..c2cbd8ce7 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -6,6 +6,7 @@ import os import shutil import sys import importlib.util +import importlib.metadata import platform import json from functools import lru_cache @@ -119,11 +120,16 @@ def run(command, desc=None, errdesc=None, custom_env=None, live: bool = default_ def is_installed(package): try: - spec = importlib.util.find_spec(package) - except ModuleNotFoundError: - return False + dist = importlib.metadata.distribution(package) + except importlib.metadata.PackageNotFoundError: + try: + spec = importlib.util.find_spec(package) + except ModuleNotFoundError: + return False - return spec is not None + return spec is not None + + return dist is not None def repo_dir(name): @@ -308,24 +314,42 @@ def requirements_met(requirements_file): def prepare_environment(): - torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu118") - torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url {torch_index_url}") + torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu121") + torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.1.2 torchvision==0.16.2 --extra-index-url {torch_index_url}") + if args.use_ipex: + if platform.system() == "Windows": + # The "Nuullll/intel-extension-for-pytorch" wheels were built from IPEX source for Intel Arc GPU: https://github.com/intel/intel-extension-for-pytorch/tree/xpu-main + # This is NOT an Intel official release so please use it at your own risk!! + # See https://github.com/Nuullll/intel-extension-for-pytorch/releases/tag/v2.0.110%2Bxpu-master%2Bdll-bundle for details. + # + # Strengths (over official IPEX 2.0.110 windows release): + # - AOT build (for Arc GPU only) to eliminate JIT compilation overhead: https://github.com/intel/intel-extension-for-pytorch/issues/399 + # - Bundles minimal oneAPI 2023.2 dependencies into the python wheels, so users don't need to install oneAPI for the whole system. + # - Provides a compatible torchvision wheel: https://github.com/intel/intel-extension-for-pytorch/issues/465 + # Limitation: + # - Only works for python 3.10 + url_prefix = "https://github.com/Nuullll/intel-extension-for-pytorch/releases/download/v2.0.110%2Bxpu-master%2Bdll-bundle" + torch_command = os.environ.get('TORCH_COMMAND', f"pip install {url_prefix}/torch-2.0.0a0+gite9ebda2-cp310-cp310-win_amd64.whl {url_prefix}/torchvision-0.15.2a0+fa99a53-cp310-cp310-win_amd64.whl {url_prefix}/intel_extension_for_pytorch-2.0.110+gitc6ea20b-cp310-cp310-win_amd64.whl") + else: + # Using official IPEX release for linux since it's already an AOT build. + # However, users still have to install oneAPI toolkit and activate oneAPI environment manually. + # See https://intel.github.io/intel-extension-for-pytorch/index.html#installation for details. + torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://pytorch-extension.intel.com/release-whl/stable/xpu/us/") + torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.0a0 intel-extension-for-pytorch==2.0.110+gitba7f6c1 --extra-index-url {torch_index_url}") requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt") - xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.20') + xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.23.post1') clip_package = os.environ.get('CLIP_PACKAGE', "https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip") openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip") stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git") stable_diffusion_xl_repo = os.environ.get('STABLE_DIFFUSION_XL_REPO', "https://github.com/Stability-AI/generative-models.git") k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git') - codeformer_repo = os.environ.get('CODEFORMER_REPO', 'https://github.com/sczhou/CodeFormer.git') blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git') stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf") stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "45c443b316737a4ab6e40413d7794a7f5657c19f") k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "ab527a9a6d347f364e3d185ba6d714e22d80cb3c") - codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af") blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9") try: @@ -352,6 +376,8 @@ def prepare_environment(): run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True) startup_timer.record("install torch") + if args.use_ipex: + args.skip_torch_cuda_test = True if not args.skip_torch_cuda_test and not check_run_python("import torch; assert torch.cuda.is_available()"): raise RuntimeError( 'Torch is not able to use GPU; ' @@ -380,15 +406,10 @@ def prepare_environment(): git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash) git_clone(stable_diffusion_xl_repo, repo_dir('generative-models'), "Stable Diffusion XL", stable_diffusion_xl_commit_hash) git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash) - git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash) git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash) startup_timer.record("clone repositores") - if not is_installed("lpips"): - run_pip(f"install -r \"{os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}\"", "requirements for CodeFormer") - startup_timer.record("install CodeFormer requirements") - if not os.path.isfile(requirements_file): requirements_file = os.path.join(script_path, requirements_file) @@ -441,7 +462,7 @@ def dump_sysinfo(): import datetime text = sysinfo.get() - filename = f"sysinfo-{datetime.datetime.utcnow().strftime('%Y-%m-%d-%H-%M')}.txt" + filename = f"sysinfo-{datetime.datetime.utcnow().strftime('%Y-%m-%d-%H-%M')}.json" with open(filename, "w", encoding="utf8") as file: file.write(text) diff --git a/modules/logging_config.py b/modules/logging_config.py index 7db23d4b6..792698756 100644 --- a/modules/logging_config.py +++ b/modules/logging_config.py @@ -1,16 +1,41 @@ import os import logging +try: + from tqdm.auto import tqdm + + class TqdmLoggingHandler(logging.Handler): + def __init__(self, level=logging.INFO): + super().__init__(level) + + def emit(self, record): + try: + msg = self.format(record) + tqdm.write(msg) + self.flush() + except Exception: + self.handleError(record) + + TQDM_IMPORTED = True +except ImportError: + # tqdm does not exist before first launch + # I will import once the UI finishes seting up the enviroment and reloads. + TQDM_IMPORTED = False def setup_logging(loglevel): if loglevel is None: loglevel = os.environ.get("SD_WEBUI_LOG_LEVEL") + loghandlers = [] + + if TQDM_IMPORTED: + loghandlers.append(TqdmLoggingHandler()) + if loglevel: log_level = getattr(logging, loglevel.upper(), None) or logging.INFO logging.basicConfig( level=log_level, format='%(asctime)s %(levelname)s [%(name)s] %(message)s', datefmt='%Y-%m-%d %H:%M:%S', + handlers=loghandlers ) - diff --git a/modules/mac_specific.py b/modules/mac_specific.py index 89256c5b0..d96d86d79 100644 --- a/modules/mac_specific.py +++ b/modules/mac_specific.py @@ -1,6 +1,7 @@ import logging import torch +from torch import Tensor import platform from modules.sd_hijack_utils import CondFunc from packaging import version @@ -51,6 +52,17 @@ def cumsum_fix(input, cumsum_func, *args, **kwargs): return cumsum_func(input, *args, **kwargs) +# MPS workaround for https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14046 +def interpolate_with_fp32_fallback(orig_func, *args, **kwargs) -> Tensor: + try: + return orig_func(*args, **kwargs) + except RuntimeError as e: + if "not implemented for" in str(e) and "Half" in str(e): + input_tensor = args[0] + return orig_func(input_tensor.to(torch.float32), *args[1:], **kwargs).to(input_tensor.dtype) + else: + print(f"An unexpected RuntimeError occurred: {str(e)}") + if has_mps: if platform.mac_ver()[0].startswith("13.2."): # MPS workaround for https://github.com/pytorch/pytorch/issues/95188, thanks to danieldk (https://github.com/explosion/curated-transformers/pull/124) @@ -77,6 +89,9 @@ if has_mps: # MPS workaround for https://github.com/pytorch/pytorch/issues/96113 CondFunc('torch.nn.functional.layer_norm', lambda orig_func, x, normalized_shape, weight, bias, eps, **kwargs: orig_func(x.float(), normalized_shape, weight.float() if weight is not None else None, bias.float() if bias is not None else bias, eps).to(x.dtype), lambda _, input, *args, **kwargs: len(args) == 4 and input.device.type == 'mps') + # MPS workaround for https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14046 + CondFunc('torch.nn.functional.interpolate', interpolate_with_fp32_fallback, None) + # MPS workaround for https://github.com/pytorch/pytorch/issues/92311 if platform.processor() == 'i386': for funcName in ['torch.argmax', 'torch.Tensor.argmax']: diff --git a/modules/modelloader.py b/modules/modelloader.py index 098bcb793..a71941375 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -1,13 +1,20 @@ from __future__ import annotations -import os -import shutil import importlib +import logging +import os +from typing import TYPE_CHECKING from urllib.parse import urlparse +import torch + from modules import shared from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone -from modules.paths import script_path, models_path + +if TYPE_CHECKING: + import spandrel + +logger = logging.getLogger(__name__) def load_file_from_url( @@ -90,54 +97,6 @@ def friendly_name(file: str): return model_name -def cleanup_models(): - # This code could probably be more efficient if we used a tuple list or something to store the src/destinations - # and then enumerate that, but this works for now. In the future, it'd be nice to just have every "model" scaler - # somehow auto-register and just do these things... - root_path = script_path - src_path = models_path - dest_path = os.path.join(models_path, "Stable-diffusion") - move_files(src_path, dest_path, ".ckpt") - move_files(src_path, dest_path, ".safetensors") - src_path = os.path.join(root_path, "ESRGAN") - dest_path = os.path.join(models_path, "ESRGAN") - move_files(src_path, dest_path) - src_path = os.path.join(models_path, "BSRGAN") - dest_path = os.path.join(models_path, "ESRGAN") - move_files(src_path, dest_path, ".pth") - src_path = os.path.join(root_path, "gfpgan") - dest_path = os.path.join(models_path, "GFPGAN") - move_files(src_path, dest_path) - src_path = os.path.join(root_path, "SwinIR") - dest_path = os.path.join(models_path, "SwinIR") - move_files(src_path, dest_path) - src_path = os.path.join(root_path, "repositories/latent-diffusion/experiments/pretrained_models/") - dest_path = os.path.join(models_path, "LDSR") - move_files(src_path, dest_path) - - -def move_files(src_path: str, dest_path: str, ext_filter: str = None): - try: - os.makedirs(dest_path, exist_ok=True) - if os.path.exists(src_path): - for file in os.listdir(src_path): - fullpath = os.path.join(src_path, file) - if os.path.isfile(fullpath): - if ext_filter is not None: - if ext_filter not in file: - continue - print(f"Moving {file} from {src_path} to {dest_path}.") - try: - shutil.move(fullpath, dest_path) - except Exception: - pass - if len(os.listdir(src_path)) == 0: - print(f"Removing empty folder: {src_path}") - shutil.rmtree(src_path, True) - except Exception: - pass - - def load_upscalers(): # We can only do this 'magic' method to dynamically load upscalers if they are referenced, # so we'll try to import any _model.py files before looking in __subclasses__ @@ -177,3 +136,26 @@ def load_upscalers(): # Special case for UpscalerNone keeps it at the beginning of the list. key=lambda x: x.name.lower() if not isinstance(x.scaler, (UpscalerNone, UpscalerLanczos, UpscalerNearest)) else "" ) + + +def load_spandrel_model( + path: str, + *, + device: str | torch.device | None, + half: bool = False, + dtype: str | torch.dtype | None = None, + expected_architecture: str | None = None, +) -> spandrel.ModelDescriptor: + import spandrel + model_descriptor = spandrel.ModelLoader(device=device).load_from_file(path) + if expected_architecture and model_descriptor.architecture != expected_architecture: + logger.warning( + f"Model {path!r} is not a {expected_architecture!r} model (got {model_descriptor.architecture!r})", + ) + if half: + model_descriptor.model.half() + if dtype: + model_descriptor.model.to(dtype=dtype) + model_descriptor.model.eval() + logger.debug("Loaded %s from %s (device=%s, half=%s, dtype=%s)", model_descriptor, path, device, half, dtype) + return model_descriptor diff --git a/modules/models/diffusion/ddpm_edit.py b/modules/models/diffusion/ddpm_edit.py index b892d5fc7..6db340da4 100644 --- a/modules/models/diffusion/ddpm_edit.py +++ b/modules/models/diffusion/ddpm_edit.py @@ -24,10 +24,15 @@ from pytorch_lightning.utilities.distributed import rank_zero_only from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config from ldm.modules.ema import LitEma from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution -from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL +from ldm.models.autoencoder import IdentityFirstStage, AutoencoderKL from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like from ldm.models.diffusion.ddim import DDIMSampler +try: + from ldm.models.autoencoder import VQModelInterface +except Exception: + class VQModelInterface: + pass __conditioning_keys__ = {'concat': 'c_concat', 'crossattn': 'c_crossattn', diff --git a/modules/options.py b/modules/options.py index e270a42df..09ff9403d 100644 --- a/modules/options.py +++ b/modules/options.py @@ -1,5 +1,6 @@ import json import sys +from dataclasses import dataclass import gradio as gr @@ -8,13 +9,14 @@ from modules.shared_cmd_options import cmd_opts class OptionInfo: - def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None, comment_before='', comment_after='', infotext=None, restrict_api=False): + def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None, comment_before='', comment_after='', infotext=None, restrict_api=False, category_id=None): self.default = default self.label = label self.component = component self.component_args = component_args self.onchange = onchange self.section = section + self.category_id = category_id self.refresh = refresh self.do_not_save = False @@ -63,7 +65,11 @@ class OptionHTML(OptionInfo): def options_section(section_identifier, options_dict): for v in options_dict.values(): - v.section = section_identifier + if len(section_identifier) == 2: + v.section = section_identifier + elif len(section_identifier) == 3: + v.section = section_identifier[0:2] + v.category_id = section_identifier[2] return options_dict @@ -76,7 +82,7 @@ class Options: def __init__(self, data_labels: dict[str, OptionInfo], restricted_opts): self.data_labels = data_labels - self.data = {k: v.default for k, v in self.data_labels.items()} + self.data = {k: v.default for k, v in self.data_labels.items() if not v.do_not_save} self.restricted_opts = restricted_opts def __setattr__(self, key, value): @@ -175,7 +181,7 @@ class Options: assert not cmd_opts.freeze_settings, "saving settings is disabled" with open(filename, "w", encoding="utf8") as file: - json.dump(self.data, file, indent=4) + json.dump(self.data, file, indent=4, ensure_ascii=False) def same_type(self, x, y): if x is None or y is None: @@ -223,21 +229,59 @@ class Options: d = {k: self.data.get(k, v.default) for k, v in self.data_labels.items()} d["_comments_before"] = {k: v.comment_before for k, v in self.data_labels.items() if v.comment_before is not None} d["_comments_after"] = {k: v.comment_after for k, v in self.data_labels.items() if v.comment_after is not None} + + item_categories = {} + for item in self.data_labels.values(): + category = categories.mapping.get(item.category_id) + category = "Uncategorized" if category is None else category.label + if category not in item_categories: + item_categories[category] = item.section[1] + + # _categories is a list of pairs: [section, category]. Each section (a setting page) will get a special heading above it with the category as text. + d["_categories"] = [[v, k] for k, v in item_categories.items()] + [["Defaults", "Other"]] + return json.dumps(d) def add_option(self, key, info): self.data_labels[key] = info + if key not in self.data and not info.do_not_save: + self.data[key] = info.default def reorder(self): - """reorder settings so that all items related to section always go together""" + """Reorder settings so that: + - all items related to section always go together + - all sections belonging to a category go together + - sections inside a category are ordered alphabetically + - categories are ordered by creation order + + Category is a superset of sections: for category "postprocessing" there could be multiple sections: "face restoration", "upscaling". + + This function also changes items' category_id so that all items belonging to a section have the same category_id. + """ + + category_ids = {} + section_categories = {} - section_ids = {} settings_items = self.data_labels.items() for _, item in settings_items: - if item.section not in section_ids: - section_ids[item.section] = len(section_ids) + if item.section not in section_categories: + section_categories[item.section] = item.category_id - self.data_labels = dict(sorted(settings_items, key=lambda x: section_ids[x[1].section])) + for _, item in settings_items: + item.category_id = section_categories.get(item.section) + + for category_id in categories.mapping: + if category_id not in category_ids: + category_ids[category_id] = len(category_ids) + + def sort_key(x): + item: OptionInfo = x[1] + category_order = category_ids.get(item.category_id, len(category_ids)) + section_order = item.section[1] + + return category_order, section_order + + self.data_labels = dict(sorted(settings_items, key=sort_key)) def cast_value(self, key, value): """casts an arbitrary to the same type as this setting's value with key @@ -260,3 +304,22 @@ class Options: value = expected_type(value) return value + + +@dataclass +class OptionsCategory: + id: str + label: str + +class OptionsCategories: + def __init__(self): + self.mapping = {} + + def register_category(self, category_id, label): + if category_id in self.mapping: + return category_id + + self.mapping[category_id] = OptionsCategory(category_id, label) + + +categories = OptionsCategories() diff --git a/modules/paths.py b/modules/paths.py index 187b94961..030646519 100644 --- a/modules/paths.py +++ b/modules/paths.py @@ -38,7 +38,6 @@ mute_sdxl_imports() path_dirs = [ (sd_path, 'ldm', 'Stable Diffusion', []), (os.path.join(sd_path, '../generative-models'), 'sgm', 'Stable Diffusion XL', ["sgm"]), - (os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []), (os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []), (os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]), ] diff --git a/modules/paths_internal.py b/modules/paths_internal.py index 89131a54f..b86ecd7f1 100644 --- a/modules/paths_internal.py +++ b/modules/paths_internal.py @@ -28,5 +28,6 @@ models_path = os.path.join(data_path, "models") extensions_dir = os.path.join(data_path, "extensions") extensions_builtin_dir = os.path.join(script_path, "extensions-builtin") config_states_dir = os.path.join(script_path, "config_states") +default_output_dir = os.path.join(data_path, "output") roboto_ttf_file = os.path.join(modules_path, 'Roboto-Regular.ttf') diff --git a/modules/postprocessing.py b/modules/postprocessing.py index cf04d38b0..facea899f 100644 --- a/modules/postprocessing.py +++ b/modules/postprocessing.py @@ -2,7 +2,7 @@ import os from PIL import Image -from modules import shared, images, devices, scripts, scripts_postprocessing, ui_common, generation_parameters_copypaste +from modules import shared, images, devices, scripts, scripts_postprocessing, ui_common from modules.shared import opts @@ -29,11 +29,7 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, image_list = shared.listfiles(input_dir) for filename in image_list: - try: - image = Image.open(filename) - except Exception: - continue - yield image, filename + yield filename, filename else: assert image, 'image not selected' yield image, None @@ -45,43 +41,97 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, infotext = '' - for image_data, name in get_images(extras_mode, image, image_folder, input_dir): + data_to_process = list(get_images(extras_mode, image, image_folder, input_dir)) + shared.state.job_count = len(data_to_process) + + for image_placeholder, name in data_to_process: image_data: Image.Image + shared.state.nextjob() shared.state.textinfo = name + shared.state.skipped = False + + if shared.state.interrupted: + break + + if isinstance(image_placeholder, str): + try: + image_data = Image.open(image_placeholder) + except Exception: + continue + else: + image_data = image_placeholder + + shared.state.assign_current_image(image_data) parameters, existing_pnginfo = images.read_info_from_image(image_data) if parameters: existing_pnginfo["parameters"] = parameters - pp = scripts_postprocessing.PostprocessedImage(image_data.convert("RGB")) + initial_pp = scripts_postprocessing.PostprocessedImage(image_data.convert("RGB")) - scripts.scripts_postproc.run(pp, args) + scripts.scripts_postproc.run(initial_pp, args) - if opts.use_original_name_batch and name is not None: - basename = os.path.splitext(os.path.basename(name))[0] - else: - basename = '' + if shared.state.skipped: + continue - infotext = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in pp.info.items() if v is not None]) + used_suffixes = {} + for pp in [initial_pp, *initial_pp.extra_images]: + suffix = pp.get_suffix(used_suffixes) - if opts.enable_pnginfo: - pp.image.info = existing_pnginfo - pp.image.info["postprocessing"] = infotext + if opts.use_original_name_batch and name is not None: + basename = os.path.splitext(os.path.basename(name))[0] + forced_filename = basename + suffix + else: + basename = '' + forced_filename = None - if save_output: - images.save_image(pp.image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=infotext, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None) + infotext = ", ".join([k if k == v else f'{k}: {infotext.quote(v)}' for k, v in pp.info.items() if v is not None]) - if extras_mode != 2 or show_extras_results: - outputs.append(pp.image) + if opts.enable_pnginfo: + pp.image.info = existing_pnginfo + pp.image.info["postprocessing"] = infotext + + if save_output: + fullfn, _ = images.save_image(pp.image, path=outpath, basename=basename, extension=opts.samples_format, info=infotext, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=forced_filename, suffix=suffix) + + if pp.caption: + caption_filename = os.path.splitext(fullfn)[0] + ".txt" + if os.path.isfile(caption_filename): + with open(caption_filename, encoding="utf8") as file: + existing_caption = file.read().strip() + else: + existing_caption = "" + + action = shared.opts.postprocessing_existing_caption_action + if action == 'Prepend' and existing_caption: + caption = f"{existing_caption} {pp.caption}" + elif action == 'Append' and existing_caption: + caption = f"{pp.caption} {existing_caption}" + elif action == 'Keep' and existing_caption: + caption = existing_caption + else: + caption = pp.caption + + caption = caption.strip() + if caption: + with open(caption_filename, "w", encoding="utf8") as file: + file.write(caption) + + if extras_mode != 2 or show_extras_results: + outputs.append(pp.image) image_data.close() devices.torch_gc() - + shared.state.end() return outputs, ui_common.plaintext_to_html(infotext), '' +def run_postprocessing_webui(id_task, *args, **kwargs): + return run_postprocessing(*args, **kwargs) + + def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True): """old handler for API""" @@ -97,9 +147,11 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_ "upscaler_2_visibility": extras_upscaler_2_visibility, }, "GFPGAN": { + "enable": True, "gfpgan_visibility": gfpgan_visibility, }, "CodeFormer": { + "enable": True, "codeformer_visibility": codeformer_visibility, "codeformer_weight": codeformer_weight, }, diff --git a/modules/processing.py b/modules/processing.py index 40598f5cf..f55b85ed3 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -16,7 +16,7 @@ from skimage import exposure from typing import Any import modules.sd_hijack -from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng +from modules import devices, prompt_parser, masking, sd_samplers, lowvram, infotext, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng from modules.rng import slerp # noqa: F401 from modules.sd_hijack import model_hijack from modules.sd_samplers_common import images_tensor_to_samples, decode_first_stage, approximation_indexes @@ -62,18 +62,22 @@ def apply_color_correction(correction, original_image): return image.convert('RGB') -def apply_overlay(image, paste_loc, index, overlays): - if overlays is None or index >= len(overlays): +def uncrop(image, dest_size, paste_loc): + x, y, w, h = paste_loc + base_image = Image.new('RGBA', dest_size) + image = images.resize_image(1, image, w, h) + base_image.paste(image, (x, y)) + image = base_image + + return image + + +def apply_overlay(image, paste_loc, overlay): + if overlay is None: return image - overlay = overlays[index] - if paste_loc is not None: - x, y, w, h = paste_loc - base_image = Image.new('RGBA', (overlay.width, overlay.height)) - image = images.resize_image(1, image, w, h) - base_image.paste(image, (x, y)) - image = base_image + image = uncrop(image, (overlay.width, overlay.height), paste_loc) image = image.convert('RGBA') image.alpha_composite(overlay) @@ -81,9 +85,12 @@ def apply_overlay(image, paste_loc, index, overlays): return image -def create_binary_mask(image): +def create_binary_mask(image, round=True): if image.mode == 'RGBA' and image.getextrema()[-1] != (255, 255): - image = image.split()[-1].convert("L").point(lambda x: 255 if x > 128 else 0) + if round: + image = image.split()[-1].convert("L").point(lambda x: 255 if x > 128 else 0) + else: + image = image.split()[-1].convert("L") else: image = image.convert('L') return image @@ -106,6 +113,21 @@ def txt2img_image_conditioning(sd_model, x, width, height): return x.new_zeros(x.shape[0], 2*sd_model.noise_augmentor.time_embed.dim, dtype=x.dtype, device=x.device) else: + sd = sd_model.model.state_dict() + diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None) + if diffusion_model_input is not None: + if diffusion_model_input.shape[1] == 9: + # The "masked-image" in this case will just be all 0.5 since the entire image is masked. + image_conditioning = torch.ones(x.shape[0], 3, height, width, device=x.device) * 0.5 + image_conditioning = images_tensor_to_samples(image_conditioning, + approximation_indexes.get(opts.sd_vae_encode_method)) + + # Add the fake full 1s mask to the first dimension. + image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0) + image_conditioning = image_conditioning.to(x.dtype) + + return image_conditioning + # Dummy zero conditioning if we're not using inpainting or unclip models. # Still takes up a bit of memory, but no encoder call. # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size. @@ -296,7 +318,7 @@ class StableDiffusionProcessing: return conditioning def edit_image_conditioning(self, source_image): - conditioning_image = images_tensor_to_samples(source_image*0.5+0.5, approximation_indexes.get(opts.sd_vae_encode_method)) + conditioning_image = shared.sd_model.encode_first_stage(source_image).mode() return conditioning_image @@ -308,7 +330,7 @@ class StableDiffusionProcessing: c_adm = torch.cat((c_adm, noise_level_emb), 1) return c_adm - def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None): + def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None, round_image_mask=True): self.is_using_inpainting_conditioning = True # Handle the different mask inputs @@ -320,8 +342,10 @@ class StableDiffusionProcessing: conditioning_mask = conditioning_mask.astype(np.float32) / 255.0 conditioning_mask = torch.from_numpy(conditioning_mask[None, None]) - # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0 - conditioning_mask = torch.round(conditioning_mask) + if round_image_mask: + # Caller is requesting a discretized mask as input, so we round to either 1.0 or 0.0 + conditioning_mask = torch.round(conditioning_mask) + else: conditioning_mask = source_image.new_ones(1, 1, *source_image.shape[-2:]) @@ -345,7 +369,7 @@ class StableDiffusionProcessing: return image_conditioning - def img2img_image_conditioning(self, source_image, latent_image, image_mask=None): + def img2img_image_conditioning(self, source_image, latent_image, image_mask=None, round_image_mask=True): source_image = devices.cond_cast_float(source_image) # HACK: Using introspection as the Depth2Image model doesn't appear to uniquely @@ -357,11 +381,17 @@ class StableDiffusionProcessing: return self.edit_image_conditioning(source_image) if self.sampler.conditioning_key in {'hybrid', 'concat'}: - return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask) + return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask, round_image_mask=round_image_mask) if self.sampler.conditioning_key == "crossattn-adm": return self.unclip_image_conditioning(source_image) + sd = self.sampler.model_wrap.inner_model.model.state_dict() + diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None) + if diffusion_model_input is not None: + if diffusion_model_input.shape[1] == 9: + return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask) + # Dummy zero conditioning if we're not using inpainting or depth model. return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1) @@ -422,6 +452,8 @@ class StableDiffusionProcessing: opts.sdxl_crop_top, self.width, self.height, + opts.fp8_storage, + opts.cache_fp16_weight, ) def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data, hires_steps=None): @@ -596,20 +628,33 @@ def decode_latent_batch(model, batch, target_device=None, check_for_nans=False): sample = decode_first_stage(model, batch[i:i + 1])[0] if check_for_nans: + try: devices.test_for_nans(sample, "vae") except devices.NansException as e: - if devices.dtype_vae == torch.float32 or not shared.opts.auto_vae_precision: + if shared.opts.auto_vae_precision_bfloat16: + autofix_dtype = torch.bfloat16 + autofix_dtype_text = "bfloat16" + autofix_dtype_setting = "Automatically convert VAE to bfloat16" + autofix_dtype_comment = "" + elif shared.opts.auto_vae_precision: + autofix_dtype = torch.float32 + autofix_dtype_text = "32-bit float" + autofix_dtype_setting = "Automatically revert VAE to 32-bit floats" + autofix_dtype_comment = "\nTo always start with 32-bit VAE, use --no-half-vae commandline flag." + else: + raise e + + if devices.dtype_vae == autofix_dtype: raise e errors.print_error_explanation( "A tensor with all NaNs was produced in VAE.\n" - "Web UI will now convert VAE into 32-bit float and retry.\n" - "To disable this behavior, disable the 'Automatically revert VAE to 32-bit floats' setting.\n" - "To always start with 32-bit VAE, use --no-half-vae commandline flag." + f"Web UI will now convert VAE into {autofix_dtype_text} and retry.\n" + f"To disable this behavior, disable the '{autofix_dtype_setting}' setting.{autofix_dtype_comment}" ) - devices.dtype_vae = torch.float32 + devices.dtype_vae = autofix_dtype model.first_stage_model.to(devices.dtype_vae) batch = batch.to(devices.dtype_vae) @@ -679,8 +724,10 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter "Size": f"{p.width}x{p.height}", "Model hash": p.sd_model_hash if opts.add_model_hash_to_info else None, "Model": p.sd_model_name if opts.add_model_name_to_info else None, - "VAE hash": p.sd_vae_hash if opts.add_model_hash_to_info else None, - "VAE": p.sd_vae_name if opts.add_model_name_to_info else None, + "FP8 weight": opts.fp8_storage if devices.fp8 else None, + "Cache FP16 weight for LoRA": opts.cache_fp16_weight if devices.fp8 else None, + "VAE hash": p.sd_vae_hash if opts.add_vae_hash_to_info else None, + "VAE": p.sd_vae_name if opts.add_vae_name_to_info else None, "Variation seed": (None if p.subseed_strength == 0 else (p.all_subseeds[0] if use_main_prompt else all_subseeds[index])), "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength), "Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"), @@ -699,7 +746,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter "User": p.user if opts.add_user_name_to_info else None, } - generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None]) + generation_params_text = ", ".join([k if k == v else f'{k}: {infotext.quote(v)}' for k, v in generation_params.items() if v is not None]) prompt_text = p.main_prompt if use_main_prompt else all_prompts[index] negative_prompt_text = f"\nNegative prompt: {p.main_negative_prompt if use_main_prompt else all_negative_prompts[index]}" if all_negative_prompts[index] else "" @@ -799,7 +846,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: infotexts = [] output_images = [] - with torch.no_grad(), p.sd_model.ema_scope(): with devices.autocast(): p.init(p.all_prompts, p.all_seeds, p.all_subseeds) @@ -819,7 +865,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if state.skipped: state.skipped = False - if state.interrupted: + if state.interrupted or state.stopping_generation: break sd_models.reload_model_weights() # model can be changed for example by refiner @@ -865,15 +911,47 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if p.n_iter > 1: shared.state.job = f"Batch {n+1} out of {p.n_iter}" + def rescale_zero_terminal_snr_abar(alphas_cumprod): + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= (alphas_bar_sqrt_T) + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + alphas_bar[-1] = 4.8973451890853435e-08 + return alphas_bar + + if hasattr(p.sd_model, 'alphas_cumprod') and hasattr(p.sd_model, 'alphas_cumprod_original'): + p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod_original.to(shared.device) + + if opts.use_downcasted_alpha_bar: + p.extra_generation_params['Downcast alphas_cumprod'] = opts.use_downcasted_alpha_bar + p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod.half().to(shared.device) + if opts.sd_noise_schedule == "Zero Terminal SNR": + p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule + p.sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(p.sd_model.alphas_cumprod).to(shared.device) + with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast(): samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts) + if p.scripts is not None: + ps = scripts.PostSampleArgs(samples_ddim) + p.scripts.post_sample(p, ps) + samples_ddim = ps.samples + if getattr(samples_ddim, 'already_decoded', False): x_samples_ddim = samples_ddim else: if opts.sd_vae_decode_method != 'Full': p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method - x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True) x_samples_ddim = torch.stack(x_samples_ddim).float() @@ -886,6 +964,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: devices.torch_gc() + state.nextjob() + if p.scripts is not None: p.scripts.postprocess_batch(p, x_samples_ddim, batch_number=n) @@ -922,13 +1002,31 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: pp = scripts.PostprocessImageArgs(image) p.scripts.postprocess_image(p, pp) image = pp.image + + mask_for_overlay = getattr(p, "mask_for_overlay", None) + overlay_image = p.overlay_images[i] if getattr(p, "overlay_images", None) is not None and i < len(p.overlay_images) else None + + if p.scripts is not None: + ppmo = scripts.PostProcessMaskOverlayArgs(i, mask_for_overlay, overlay_image) + p.scripts.postprocess_maskoverlay(p, ppmo) + mask_for_overlay, overlay_image = ppmo.mask_for_overlay, ppmo.overlay_image + if p.color_corrections is not None and i < len(p.color_corrections): if save_samples and opts.save_images_before_color_correction: - image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images) + image_without_cc = apply_overlay(image, p.paste_to, overlay_image) images.save_image(image_without_cc, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-color-correction") image = apply_color_correction(p.color_corrections[i], image) - image = apply_overlay(image, p.paste_to, i, p.overlay_images) + # If the intention is to show the output from the model + # that is being composited over the original image, + # we need to keep the original image around + # and use it in the composite step. + original_denoised_image = image.copy() + + if p.paste_to is not None: + original_denoised_image = uncrop(original_denoised_image, (overlay_image.width, overlay_image.height), p.paste_to) + + image = apply_overlay(image, p.paste_to, overlay_image) if save_samples: images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p) @@ -938,28 +1036,26 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if opts.enable_pnginfo: image.info["parameters"] = text output_images.append(image) - if save_samples and hasattr(p, 'mask_for_overlay') and p.mask_for_overlay and any([opts.save_mask, opts.save_mask_composite, opts.return_mask, opts.return_mask_composite]): - image_mask = p.mask_for_overlay.convert('RGB') - image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA') - if opts.save_mask: - images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask") + if mask_for_overlay is not None: + if opts.return_mask or opts.save_mask: + image_mask = mask_for_overlay.convert('RGB') + if save_samples and opts.save_mask: + images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask") + if opts.return_mask: + output_images.append(image_mask) - if opts.save_mask_composite: - images.save_image(image_mask_composite, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask-composite") - - if opts.return_mask: - output_images.append(image_mask) - - if opts.return_mask_composite: - output_images.append(image_mask_composite) + if opts.return_mask_composite or opts.save_mask_composite: + image_mask_composite = Image.composite(original_denoised_image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA') + if save_samples and opts.save_mask_composite: + images.save_image(image_mask_composite, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask-composite") + if opts.return_mask_composite: + output_images.append(image_mask_composite) del x_samples_ddim devices.torch_gc() - state.nextjob() - if not infotexts: infotexts.append(Processed(p, []).infotext(p, 0)) @@ -1028,6 +1124,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): hr_sampler_name: str = None hr_prompt: str = '' hr_negative_prompt: str = '' + force_task_id: str = None cached_hr_uc = [None, None] cached_hr_c = [None, None] @@ -1100,7 +1197,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): def init(self, all_prompts, all_seeds, all_subseeds): if self.enable_hr: - if self.hr_checkpoint_name: + if self.hr_checkpoint_name and self.hr_checkpoint_name != 'Use same checkpoint': self.hr_checkpoint_info = sd_models.get_closet_checkpoint_match(self.hr_checkpoint_name) if self.hr_checkpoint_info is None: @@ -1147,6 +1244,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): if not self.enable_hr: return samples + devices.torch_gc() if self.latent_scale_mode is None: decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32) @@ -1156,8 +1254,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): with sd_models.SkipWritingToConfig(): sd_models.reload_model_weights(info=self.hr_checkpoint_info) - devices.torch_gc() - return self.sample_hr_pass(samples, decoded_samples, seeds, subseeds, subseed_strength, prompts) def sample_hr_pass(self, samples, decoded_samples, seeds, subseeds, subseed_strength, prompts): @@ -1165,7 +1261,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): return samples self.is_hr_pass = True - target_width = self.hr_upscale_to_x target_height = self.hr_upscale_to_y @@ -1254,7 +1349,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True) self.is_hr_pass = False - return decoded_samples def close(self): @@ -1357,12 +1451,14 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): mask_blur_x: int = 4 mask_blur_y: int = 4 mask_blur: int = None + mask_round: bool = True inpainting_fill: int = 0 inpaint_full_res: bool = True inpaint_full_res_padding: int = 0 inpainting_mask_invert: int = 0 initial_noise_multiplier: float = None latent_mask: Image = None + force_task_id: str = None image_mask: Any = field(default=None, init=False) @@ -1402,7 +1498,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): if image_mask is not None: # image_mask is passed in as RGBA by Gradio to support alpha masks, # but we still want to support binary masks. - image_mask = create_binary_mask(image_mask) + image_mask = create_binary_mask(image_mask, round=self.mask_round) if self.inpainting_mask_invert: image_mask = ImageOps.invert(image_mask) @@ -1448,7 +1544,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): # Save init image if opts.save_init_img: self.init_img_hash = hashlib.md5(img.tobytes()).hexdigest() - images.save_image(img, path=opts.outdir_init_images, basename=None, forced_filename=self.init_img_hash, save_to_dirs=False) + images.save_image(img, path=opts.outdir_init_images, basename=None, forced_filename=self.init_img_hash, save_to_dirs=False, existing_info=img.info) image = images.flatten(img, opts.img2img_background_color) @@ -1509,7 +1605,8 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2])) latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255 latmask = latmask[0] - latmask = np.around(latmask) + if self.mask_round: + latmask = np.around(latmask) latmask = np.tile(latmask[None], (4, 1, 1)) self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(self.sd_model.dtype) @@ -1521,7 +1618,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): elif self.inpainting_fill == 3: self.init_latent = self.init_latent * self.mask - self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1, self.init_latent, image_mask) + self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1, self.init_latent, image_mask, self.mask_round) def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts): x = self.rng.next() @@ -1533,7 +1630,14 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning) if self.mask is not None: - samples = samples * self.nmask + self.init_latent * self.mask + blended_samples = samples * self.nmask + self.init_latent * self.mask + + if self.scripts is not None: + mba = scripts.MaskBlendArgs(samples, self.nmask, self.init_latent, self.mask, blended_samples) + self.scripts.on_mask_blend(self, mba) + blended_samples = mba.blended_latent + + samples = blended_samples del x devices.torch_gc() diff --git a/modules/processing_scripts/refiner.py b/modules/processing_scripts/refiner.py index 29ccb78f9..e9941413f 100644 --- a/modules/processing_scripts/refiner.py +++ b/modules/processing_scripts/refiner.py @@ -1,6 +1,7 @@ import gradio as gr from modules import scripts, sd_models +from modules.infotext import PasteField from modules.ui_common import create_refresh_button from modules.ui_components import InputAccordion @@ -31,9 +32,9 @@ class ScriptRefiner(scripts.ScriptBuiltinUI): return None if info is None else info.title self.infotext_fields = [ - (enable_refiner, lambda d: 'Refiner' in d), - (refiner_checkpoint, lambda d: lookup_checkpoint(d.get('Refiner'))), - (refiner_switch_at, 'Refiner switch at'), + PasteField(enable_refiner, lambda d: 'Refiner' in d), + PasteField(refiner_checkpoint, lambda d: lookup_checkpoint(d.get('Refiner')), api="refiner_checkpoint"), + PasteField(refiner_switch_at, 'Refiner switch at', api="refiner_switch_at"), ] return enable_refiner, refiner_checkpoint, refiner_switch_at diff --git a/modules/processing_scripts/seed.py b/modules/processing_scripts/seed.py index dc9c2da50..602932785 100644 --- a/modules/processing_scripts/seed.py +++ b/modules/processing_scripts/seed.py @@ -3,6 +3,7 @@ import json import gradio as gr from modules import scripts, ui, errors +from modules.infotext import PasteField from modules.shared import cmd_opts from modules.ui_components import ToolButton @@ -51,12 +52,12 @@ class ScriptSeed(scripts.ScriptBuiltinUI): seed_checkbox.change(lambda x: gr.update(visible=x), show_progress=False, inputs=[seed_checkbox], outputs=[seed_extras]) self.infotext_fields = [ - (self.seed, "Seed"), - (seed_checkbox, lambda d: "Variation seed" in d or "Seed resize from-1" in d), - (subseed, "Variation seed"), - (subseed_strength, "Variation seed strength"), - (seed_resize_from_w, "Seed resize from-1"), - (seed_resize_from_h, "Seed resize from-2"), + PasteField(self.seed, "Seed", api="seed"), + PasteField(seed_checkbox, lambda d: "Variation seed" in d or "Seed resize from-1" in d), + PasteField(subseed, "Variation seed", api="subseed"), + PasteField(subseed_strength, "Variation seed strength", api="subseed_strength"), + PasteField(seed_resize_from_w, "Seed resize from-1", api="seed_resize_from_h"), + PasteField(seed_resize_from_h, "Seed resize from-2", api="seed_resize_from_w"), ] self.on_after_component(lambda x: connect_reuse_seed(self.seed, reuse_seed, x.component, False), elem_id=f'generation_info_{self.tabname}') diff --git a/modules/progress.py b/modules/progress.py index 69921de72..85255e821 100644 --- a/modules/progress.py +++ b/modules/progress.py @@ -8,10 +8,13 @@ from pydantic import BaseModel, Field from modules.shared import opts import modules.shared as shared - +from collections import OrderedDict +import string +import random +from typing import List current_task = None -pending_tasks = {} +pending_tasks = OrderedDict() finished_tasks = [] recorded_results = [] recorded_results_limit = 2 @@ -34,6 +37,11 @@ def finish_task(id_task): if len(finished_tasks) > 16: finished_tasks.pop(0) +def create_task_id(task_type): + N = 7 + res = ''.join(random.choices(string.ascii_uppercase + + string.digits, k=N)) + return f"task({task_type}-{res})" def record_results(id_task, res): recorded_results.append((id_task, res)) @@ -44,6 +52,9 @@ def record_results(id_task, res): def add_task_to_queue(id_job): pending_tasks[id_job] = time.time() +class PendingTasksResponse(BaseModel): + size: int = Field(title="Pending task size") + tasks: List[str] = Field(title="Pending task ids") class ProgressRequest(BaseModel): id_task: str = Field(default=None, title="Task ID", description="id of the task to get progress for") @@ -63,9 +74,16 @@ class ProgressResponse(BaseModel): def setup_progress_api(app): + app.add_api_route("/internal/pending-tasks", get_pending_tasks, methods=["GET"]) return app.add_api_route("/internal/progress", progressapi, methods=["POST"], response_model=ProgressResponse) +def get_pending_tasks(): + pending_tasks_ids = list(pending_tasks) + pending_len = len(pending_tasks_ids) + return PendingTasksResponse(size=pending_len, tasks=pending_tasks_ids) + + def progressapi(req: ProgressRequest): active = req.id_task == current_task queued = req.id_task in pending_tasks diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py index ddf4d2dd4..cba134554 100644 --- a/modules/prompt_parser.py +++ b/modules/prompt_parser.py @@ -4,7 +4,7 @@ import re from collections import namedtuple import lark -# a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]" +# a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][: in background:0.25] [shoddy:masterful:0.5]" # will be represented with prompt_schedule like this (assuming steps=100): # [25, 'fantasy landscape with a mountain and an oak in foreground shoddy'] # [50, 'fantasy landscape with a lake and an oak in foreground in background shoddy'] diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py index 02841c302..4d35b695c 100644 --- a/modules/realesrgan_model.py +++ b/modules/realesrgan_model.py @@ -1,12 +1,9 @@ import os -import numpy as np -from PIL import Image -from realesrgan import RealESRGANer - -from modules.upscaler import Upscaler, UpscalerData -from modules.shared import cmd_opts, opts from modules import modelloader, errors +from modules.shared import cmd_opts, opts +from modules.upscaler import Upscaler, UpscalerData +from modules.upscaler_utils import upscale_with_model class UpscalerRealESRGAN(Upscaler): @@ -14,29 +11,20 @@ class UpscalerRealESRGAN(Upscaler): self.name = "RealESRGAN" self.user_path = path super().__init__() - try: - from basicsr.archs.rrdbnet_arch import RRDBNet # noqa: F401 - from realesrgan import RealESRGANer # noqa: F401 - from realesrgan.archs.srvgg_arch import SRVGGNetCompact # noqa: F401 - self.enable = True - self.scalers = [] - scalers = self.load_models(path) + self.enable = True + self.scalers = [] + scalers = get_realesrgan_models(self) - local_model_paths = self.find_models(ext_filter=[".pth"]) - for scaler in scalers: - if scaler.local_data_path.startswith("http"): - filename = modelloader.friendly_name(scaler.local_data_path) - local_model_candidates = [local_model for local_model in local_model_paths if local_model.endswith(f"{filename}.pth")] - if local_model_candidates: - scaler.local_data_path = local_model_candidates[0] + local_model_paths = self.find_models(ext_filter=[".pth"]) + for scaler in scalers: + if scaler.local_data_path.startswith("http"): + filename = modelloader.friendly_name(scaler.local_data_path) + local_model_candidates = [local_model for local_model in local_model_paths if local_model.endswith(f"{filename}.pth")] + if local_model_candidates: + scaler.local_data_path = local_model_candidates[0] - if scaler.name in opts.realesrgan_enabled_models: - self.scalers.append(scaler) - - except Exception: - errors.report("Error importing Real-ESRGAN", exc_info=True) - self.enable = False - self.scalers = [] + if scaler.name in opts.realesrgan_enabled_models: + self.scalers.append(scaler) def do_upscale(self, img, path): if not self.enable: @@ -48,20 +36,19 @@ class UpscalerRealESRGAN(Upscaler): errors.report(f"Unable to load RealESRGAN model {path}", exc_info=True) return img - upsampler = RealESRGANer( - scale=info.scale, - model_path=info.local_data_path, - model=info.model(), - half=not cmd_opts.no_half and not cmd_opts.upcast_sampling, - tile=opts.ESRGAN_tile, - tile_pad=opts.ESRGAN_tile_overlap, + model_descriptor = modelloader.load_spandrel_model( + info.local_data_path, device=self.device, + half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling), + expected_architecture="ESRGAN", # "RealESRGAN" isn't a specific thing for Spandrel + ) + return upscale_with_model( + model_descriptor, + img, + tile_size=opts.ESRGAN_tile, + tile_overlap=opts.ESRGAN_tile_overlap, + # TODO: `outscale`? ) - - upsampled = upsampler.enhance(np.array(img), outscale=info.scale)[0] - - image = Image.fromarray(upsampled) - return image def load_model(self, path): for scaler in self.scalers: @@ -76,58 +63,43 @@ class UpscalerRealESRGAN(Upscaler): return scaler raise ValueError(f"Unable to find model info: {path}") - def load_models(self, _): - return get_realesrgan_models(self) - -def get_realesrgan_models(scaler): - try: - from basicsr.archs.rrdbnet_arch import RRDBNet - from realesrgan.archs.srvgg_arch import SRVGGNetCompact - models = [ - UpscalerData( - name="R-ESRGAN General 4xV3", - path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth", - scale=4, - upscaler=scaler, - model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu') - ), - UpscalerData( - name="R-ESRGAN General WDN 4xV3", - path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth", - scale=4, - upscaler=scaler, - model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu') - ), - UpscalerData( - name="R-ESRGAN AnimeVideo", - path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth", - scale=4, - upscaler=scaler, - model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu') - ), - UpscalerData( - name="R-ESRGAN 4x+", - path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth", - scale=4, - upscaler=scaler, - model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4) - ), - UpscalerData( - name="R-ESRGAN 4x+ Anime6B", - path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth", - scale=4, - upscaler=scaler, - model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4) - ), - UpscalerData( - name="R-ESRGAN 2x+", - path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth", - scale=2, - upscaler=scaler, - model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2) - ), - ] - return models - except Exception: - errors.report("Error making Real-ESRGAN models list", exc_info=True) +def get_realesrgan_models(scaler: UpscalerRealESRGAN): + return [ + UpscalerData( + name="R-ESRGAN General 4xV3", + path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth", + scale=4, + upscaler=scaler, + ), + UpscalerData( + name="R-ESRGAN General WDN 4xV3", + path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth", + scale=4, + upscaler=scaler, + ), + UpscalerData( + name="R-ESRGAN AnimeVideo", + path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth", + scale=4, + upscaler=scaler, + ), + UpscalerData( + name="R-ESRGAN 4x+", + path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth", + scale=4, + upscaler=scaler, + ), + UpscalerData( + name="R-ESRGAN 4x+ Anime6B", + path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth", + scale=4, + upscaler=scaler, + ), + UpscalerData( + name="R-ESRGAN 2x+", + path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth", + scale=2, + upscaler=scaler, + ), + ] diff --git a/modules/rng.py b/modules/rng.py index 9e8ba2ee9..8934d39bf 100644 --- a/modules/rng.py +++ b/modules/rng.py @@ -110,7 +110,7 @@ class ImageRNG: self.is_first = True def first(self): - noise_shape = self.shape if self.seed_resize_from_h <= 0 or self.seed_resize_from_w <= 0 else (self.shape[0], self.seed_resize_from_h // 8, self.seed_resize_from_w // 8) + noise_shape = self.shape if self.seed_resize_from_h <= 0 or self.seed_resize_from_w <= 0 else (self.shape[0], int(self.seed_resize_from_h) // 8, int(self.seed_resize_from_w // 8)) xs = [] diff --git a/modules/scripts.py b/modules/scripts.py index 5c6e0226e..017aed5a0 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -11,11 +11,31 @@ from modules import shared, paths, script_callbacks, extensions, script_loading, AlwaysVisible = object() +class MaskBlendArgs: + def __init__(self, current_latent, nmask, init_latent, mask, blended_latent, denoiser=None, sigma=None): + self.current_latent = current_latent + self.nmask = nmask + self.init_latent = init_latent + self.mask = mask + self.blended_latent = blended_latent + + self.denoiser = denoiser + self.is_final_blend = denoiser is None + self.sigma = sigma + +class PostSampleArgs: + def __init__(self, samples): + self.samples = samples class PostprocessImageArgs: def __init__(self, image): self.image = image +class PostProcessMaskOverlayArgs: + def __init__(self, index, mask_for_overlay, overlay_image): + self.index = index + self.mask_for_overlay = mask_for_overlay + self.overlay_image = overlay_image class PostprocessBatchListArgs: def __init__(self, images): @@ -206,6 +226,25 @@ class Script: pass + def on_mask_blend(self, p, mba: MaskBlendArgs, *args): + """ + Called in inpainting mode when the original content is blended with the inpainted content. + This is called at every step in the denoising process and once at the end. + If is_final_blend is true, this is called for the final blending stage. + Otherwise, denoiser and sigma are defined and may be used to inform the procedure. + """ + + pass + + def post_sample(self, p, ps: PostSampleArgs, *args): + """ + Called after the samples have been generated, + but before they have been decoded by the VAE, if applicable. + Check getattr(samples, 'already_decoded', False) to test if the images are decoded. + """ + + pass + def postprocess_image(self, p, pp: PostprocessImageArgs, *args): """ Called for every image after it has been generated. @@ -213,6 +252,13 @@ class Script: pass + def postprocess_maskoverlay(self, p, ppmo: PostProcessMaskOverlayArgs, *args): + """ + Called for every image after it has been generated. + """ + + pass + def postprocess(self, p, processed, *args): """ This function is called after processing ends for AlwaysVisible scripts. @@ -311,20 +357,113 @@ scripts_data = [] postprocessing_scripts_data = [] ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"]) +def topological_sort(dependencies): + """Accepts a dictionary mapping name to its dependencies, returns a list of names ordered according to dependencies. + Ignores errors relating to missing dependeencies or circular dependencies + """ + + visited = {} + result = [] + + def inner(name): + visited[name] = True + + for dep in dependencies.get(name, []): + if dep in dependencies and dep not in visited: + inner(dep) + + result.append(name) + + for depname in dependencies: + if depname not in visited: + inner(depname) + + return result + + +@dataclass +class ScriptWithDependencies: + script_canonical_name: str + file: ScriptFile + requires: list + load_before: list + load_after: list + def list_scripts(scriptdirname, extension, *, include_extensions=True): - scripts_list = [] + scripts = {} - basedir = os.path.join(paths.script_path, scriptdirname) - if os.path.exists(basedir): - for filename in sorted(os.listdir(basedir)): - scripts_list.append(ScriptFile(paths.script_path, filename, os.path.join(basedir, filename))) + loaded_extensions = {ext.canonical_name: ext for ext in extensions.active()} + loaded_extensions_scripts = {ext.canonical_name: [] for ext in extensions.active()} + + # build script dependency map + root_script_basedir = os.path.join(paths.script_path, scriptdirname) + if os.path.exists(root_script_basedir): + for filename in sorted(os.listdir(root_script_basedir)): + if not os.path.isfile(os.path.join(root_script_basedir, filename)): + continue + + if os.path.splitext(filename)[1].lower() != extension: + continue + + script_file = ScriptFile(paths.script_path, filename, os.path.join(root_script_basedir, filename)) + scripts[filename] = ScriptWithDependencies(filename, script_file, [], [], []) if include_extensions: for ext in extensions.active(): - scripts_list += ext.list_files(scriptdirname, extension) + extension_scripts_list = ext.list_files(scriptdirname, extension) + for extension_script in extension_scripts_list: + if not os.path.isfile(extension_script.path): + continue - scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)] + script_canonical_name = ("builtin/" if ext.is_builtin else "") + ext.canonical_name + "/" + extension_script.filename + relative_path = scriptdirname + "/" + extension_script.filename + + script = ScriptWithDependencies( + script_canonical_name=script_canonical_name, + file=extension_script, + requires=ext.metadata.get_script_requirements("Requires", relative_path, scriptdirname), + load_before=ext.metadata.get_script_requirements("Before", relative_path, scriptdirname), + load_after=ext.metadata.get_script_requirements("After", relative_path, scriptdirname), + ) + + scripts[script_canonical_name] = script + loaded_extensions_scripts[ext.canonical_name].append(script) + + for script_canonical_name, script in scripts.items(): + # load before requires inverse dependency + # in this case, append the script name into the load_after list of the specified script + for load_before in script.load_before: + # if this requires an individual script to be loaded before + other_script = scripts.get(load_before) + if other_script: + other_script.load_after.append(script_canonical_name) + + # if this requires an extension + other_extension_scripts = loaded_extensions_scripts.get(load_before) + if other_extension_scripts: + for other_script in other_extension_scripts: + other_script.load_after.append(script_canonical_name) + + # if After mentions an extension, remove it and instead add all of its scripts + for load_after in list(script.load_after): + if load_after not in scripts and load_after in loaded_extensions_scripts: + script.load_after.remove(load_after) + + for other_script in loaded_extensions_scripts.get(load_after, []): + script.load_after.append(other_script.script_canonical_name) + + dependencies = {} + + for script_canonical_name, script in scripts.items(): + for required_script in script.requires: + if required_script not in scripts and required_script not in loaded_extensions: + errors.report(f'Script "{script_canonical_name}" requires "{required_script}" to be loaded, but it is not.', exc_info=False) + + dependencies[script_canonical_name] = script.load_after + + ordered_scripts = topological_sort(dependencies) + scripts_list = [scripts[script_canonical_name].file for script_canonical_name in ordered_scripts] return scripts_list @@ -365,15 +504,9 @@ def load_scripts(): elif issubclass(script_class, scripts_postprocessing.ScriptPostprocessing): postprocessing_scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module)) - def orderby(basedir): - # 1st webui, 2nd extensions-builtin, 3rd extensions - priority = {os.path.join(paths.script_path, "extensions-builtin"):1, paths.script_path:0} - for key in priority: - if basedir.startswith(key): - return priority[key] - return 9999 - - for scriptfile in sorted(scripts_list, key=lambda x: [orderby(x.basedir), x]): + # here the scripts_list is already ordered + # processing_script is not considered though + for scriptfile in scripts_list: try: if scriptfile.basedir != paths.script_path: sys.path = [scriptfile.basedir] + sys.path @@ -433,7 +566,12 @@ class ScriptRunner: auto_processing_scripts = scripts_auto_postprocessing.create_auto_preprocessing_script_data() for script_data in auto_processing_scripts + scripts_data: - script = script_data.script_class() + try: + script = script_data.script_class() + except Exception: + errors.report(f"Error # failed to initialize Script {script_data.module}: ", exc_info=True) + continue + script.filename = script_data.path script.is_txt2img = not is_img2img script.is_img2img = is_img2img @@ -473,17 +611,25 @@ class ScriptRunner: on_after.clear() def create_script_ui(self, script): - import modules.api.models as api_models script.args_from = len(self.inputs) script.args_to = len(self.inputs) + try: + self.create_script_ui_inner(script) + except Exception: + errors.report(f"Error creating UI for {script.name}: ", exc_info=True) + + def create_script_ui_inner(self, script): + import modules.api.models as api_models + controls = wrap_call(script.ui, script.filename, "ui", script.is_img2img) if controls is None: return script.name = wrap_call(script.title, script.filename, "title", default=script.filename).lower() + api_args = [] for control in controls: @@ -550,6 +696,8 @@ class ScriptRunner: self.setup_ui_for_section(None, self.selectable_scripts) def select_script(script_index): + if script_index is None: + script_index = 0 selected_script = self.selectable_scripts[script_index - 1] if script_index>0 else None return [gr.update(visible=selected_script == s) for s in self.selectable_scripts] @@ -593,7 +741,7 @@ class ScriptRunner: def run(self, p, *args): script_index = args[0] - if script_index == 0: + if script_index == 0 or script_index is None: return None script = self.selectable_scripts[script_index-1] @@ -672,6 +820,22 @@ class ScriptRunner: except Exception: errors.report(f"Error running postprocess_batch_list: {script.filename}", exc_info=True) + def post_sample(self, p, ps: PostSampleArgs): + for script in self.alwayson_scripts: + try: + script_args = p.script_args[script.args_from:script.args_to] + script.post_sample(p, ps, *script_args) + except Exception: + errors.report(f"Error running post_sample: {script.filename}", exc_info=True) + + def on_mask_blend(self, p, mba: MaskBlendArgs): + for script in self.alwayson_scripts: + try: + script_args = p.script_args[script.args_from:script.args_to] + script.on_mask_blend(p, mba, *script_args) + except Exception: + errors.report(f"Error running post_sample: {script.filename}", exc_info=True) + def postprocess_image(self, p, pp: PostprocessImageArgs): for script in self.alwayson_scripts: try: @@ -680,6 +844,14 @@ class ScriptRunner: except Exception: errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True) + def postprocess_maskoverlay(self, p, ppmo: PostProcessMaskOverlayArgs): + for script in self.alwayson_scripts: + try: + script_args = p.script_args[script.args_from:script.args_to] + script.postprocess_maskoverlay(p, ppmo, *script_args) + except Exception: + errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True) + def before_component(self, component, **kwargs): for callback, script in self.on_before_component_elem_id.get(kwargs.get("elem_id"), []): try: diff --git a/modules/scripts_postprocessing.py b/modules/scripts_postprocessing.py index bac1335dc..901cad080 100644 --- a/modules/scripts_postprocessing.py +++ b/modules/scripts_postprocessing.py @@ -1,13 +1,56 @@ +import dataclasses import os import gradio as gr from modules import errors, shared +@dataclasses.dataclass +class PostprocessedImageSharedInfo: + target_width: int = None + target_height: int = None + + class PostprocessedImage: def __init__(self, image): self.image = image self.info = {} + self.shared = PostprocessedImageSharedInfo() + self.extra_images = [] + self.nametags = [] + self.disable_processing = False + self.caption = None + + def get_suffix(self, used_suffixes=None): + used_suffixes = {} if used_suffixes is None else used_suffixes + suffix = "-".join(self.nametags) + if suffix: + suffix = "-" + suffix + + if suffix not in used_suffixes: + used_suffixes[suffix] = 1 + return suffix + + for i in range(1, 100): + proposed_suffix = suffix + "-" + str(i) + + if proposed_suffix not in used_suffixes: + used_suffixes[proposed_suffix] = 1 + return proposed_suffix + + return suffix + + def create_copy(self, new_image, *, nametags=None, disable_processing=False): + pp = PostprocessedImage(new_image) + pp.shared = self.shared + pp.nametags = self.nametags.copy() + pp.info = self.info.copy() + pp.disable_processing = disable_processing + + if nametags is not None: + pp.nametags += nametags + + return pp class ScriptPostprocessing: @@ -42,10 +85,17 @@ class ScriptPostprocessing: pass - def image_changed(self): + def process_firstpass(self, pp: PostprocessedImage, **args): + """ + Called for all scripts before calling process(). Scripts can examine the image here and set fields + of the pp object to communicate things to other scripts. + args contains a dictionary with all values returned by components from ui() + """ + pass - + def image_changed(self): + pass def wrap_call(func, filename, funcname, *args, default=None, **kwargs): @@ -118,16 +168,42 @@ class ScriptPostprocessingRunner: return inputs def run(self, pp: PostprocessedImage, args): - for script in self.scripts_in_preferred_order(): - shared.state.job = script.name + scripts = [] + for script in self.scripts_in_preferred_order(): script_args = args[script.args_from:script.args_to] process_args = {} for (name, _component), value in zip(script.controls.items(), script_args): process_args[name] = value - script.process(pp, **process_args) + scripts.append((script, process_args)) + + for script, process_args in scripts: + script.process_firstpass(pp, **process_args) + + all_images = [pp] + + for script, process_args in scripts: + if shared.state.skipped: + break + + shared.state.job = script.name + + for single_image in all_images.copy(): + + if not single_image.disable_processing: + script.process(single_image, **process_args) + + for extra_image in single_image.extra_images: + if not isinstance(extra_image, PostprocessedImage): + extra_image = single_image.create_copy(extra_image) + + all_images.append(extra_image) + + single_image.extra_images.clear() + + pp.extra_images = all_images[1:] def create_args_for_run(self, scripts_args): if not self.ui_created: diff --git a/modules/sd_disable_initialization.py b/modules/sd_disable_initialization.py index 8863107ae..273a7edd8 100644 --- a/modules/sd_disable_initialization.py +++ b/modules/sd_disable_initialization.py @@ -215,7 +215,7 @@ class LoadStateDictOnMeta(ReplaceHelper): would be on the meta device. """ - if state_dict == sd: + if state_dict is sd: state_dict = {k: v.to(device="meta", dtype=v.dtype) for k, v in state_dict.items()} original(module, state_dict, strict=strict) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index bc5fbcd3f..e139d9964 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -38,8 +38,12 @@ ldm.models.diffusion.ddpm.print = shared.ldm_print optimizers = [] current_optimizer: sd_hijack_optimizations.SdOptimization = None -ldm_original_forward = patches.patch(__file__, ldm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sd_unet.UNetModel_forward) -sgm_original_forward = patches.patch(__file__, sgm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sd_unet.UNetModel_forward) +ldm_patched_forward = sd_unet.create_unet_forward(ldm.modules.diffusionmodules.openaimodel.UNetModel.forward) +ldm_original_forward = patches.patch(__file__, ldm.modules.diffusionmodules.openaimodel.UNetModel, "forward", ldm_patched_forward) + +sgm_patched_forward = sd_unet.create_unet_forward(sgm.modules.diffusionmodules.openaimodel.UNetModel.forward) +sgm_original_forward = patches.patch(__file__, sgm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sgm_patched_forward) + def list_optimizers(): new_optimizers = script_callbacks.list_optimizers_callback() @@ -184,6 +188,20 @@ class StableDiffusionModelHijack: errors.display(e, "applying cross attention optimization") undo_optimizations() + def convert_sdxl_to_ssd(self, m): + """Converts an SDXL model to a Segmind Stable Diffusion model (see https://huggingface.co/segmind/SSD-1B)""" + + delattr(m.model.diffusion_model.middle_block, '1') + delattr(m.model.diffusion_model.middle_block, '2') + for i in ['9', '8', '7', '6', '5', '4']: + delattr(m.model.diffusion_model.input_blocks[7][1].transformer_blocks, i) + delattr(m.model.diffusion_model.input_blocks[8][1].transformer_blocks, i) + delattr(m.model.diffusion_model.output_blocks[0][1].transformer_blocks, i) + delattr(m.model.diffusion_model.output_blocks[1][1].transformer_blocks, i) + delattr(m.model.diffusion_model.output_blocks[4][1].transformer_blocks, '1') + delattr(m.model.diffusion_model.output_blocks[5][1].transformer_blocks, '1') + devices.torch_gc() + def hijack(self, m): conditioner = getattr(m, 'conditioner', None) if conditioner: @@ -242,8 +260,12 @@ class StableDiffusionModelHijack: self.layers = flatten(m) + import modules.models.diffusion.ddpm_edit + if isinstance(m, ldm.models.diffusion.ddpm.LatentDiffusion): sd_unet.original_forward = ldm_original_forward + elif isinstance(m, modules.models.diffusion.ddpm_edit.LatentDiffusion): + sd_unet.original_forward = ldm_original_forward elif isinstance(m, sgm.models.diffusion.DiffusionEngine): sd_unet.original_forward = sgm_original_forward else: @@ -285,8 +307,6 @@ class StableDiffusionModelHijack: self.layers = None self.clip = None - sd_unet.original_forward = None - def apply_circular(self, enable): if self.circular_enabled == enable: diff --git a/modules/sd_models.py b/modules/sd_models.py index 3b6cdea18..50bc209e4 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -230,15 +230,19 @@ def select_checkpoint(): return checkpoint_info -checkpoint_dict_replacements = { +checkpoint_dict_replacements_sd1 = { 'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.', 'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.', 'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.', } +checkpoint_dict_replacements_sd2_turbo = { # Converts SD 2.1 Turbo from SGM to LDM format. + 'conditioner.embedders.0.': 'cond_stage_model.', +} -def transform_checkpoint_dict_key(k): - for text, replacement in checkpoint_dict_replacements.items(): + +def transform_checkpoint_dict_key(k, replacements): + for text, replacement in replacements.items(): if k.startswith(text): k = replacement + k[len(text):] @@ -249,9 +253,14 @@ def get_state_dict_from_checkpoint(pl_sd): pl_sd = pl_sd.pop("state_dict", pl_sd) pl_sd.pop("state_dict", None) + is_sd2_turbo = 'conditioner.embedders.0.model.ln_final.weight' in pl_sd and pl_sd['conditioner.embedders.0.model.ln_final.weight'].size()[0] == 1024 + sd = {} for k, v in pl_sd.items(): - new_key = transform_checkpoint_dict_key(k) + if is_sd2_turbo: + new_key = transform_checkpoint_dict_key(k, checkpoint_dict_replacements_sd2_turbo) + else: + new_key = transform_checkpoint_dict_key(k, checkpoint_dict_replacements_sd1) if new_key is not None: sd[new_key] = v @@ -339,10 +348,28 @@ class SkipWritingToConfig: SkipWritingToConfig.skip = self.previous +def check_fp8(model): + if model is None: + return None + if devices.get_optimal_device_name() == "mps": + enable_fp8 = False + elif shared.opts.fp8_storage == "Enable": + enable_fp8 = True + elif getattr(model, "is_sdxl", False) and shared.opts.fp8_storage == "Enable for SDXL": + enable_fp8 = True + else: + enable_fp8 = False + return enable_fp8 + + def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer): sd_model_hash = checkpoint_info.calculate_shorthash() timer.record("calculate hash") + if devices.fp8: + # prevent model to load state dict in fp8 + model.half() + if not SkipWritingToConfig.skip: shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title @@ -352,10 +379,13 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer model.is_sdxl = hasattr(model, 'conditioner') model.is_sd2 = not model.is_sdxl and hasattr(model.cond_stage_model, 'model') model.is_sd1 = not model.is_sdxl and not model.is_sd2 - + model.is_ssd = model.is_sdxl and 'model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight' not in state_dict.keys() if model.is_sdxl: sd_models_xl.extend_sdxl(model) + if model.is_ssd: + sd_hijack.model_hijack.convert_sdxl_to_ssd(model) + if shared.opts.sd_checkpoint_cache > 0: # cache newly loaded model checkpoints_loaded[checkpoint_info] = state_dict.copy() @@ -371,6 +401,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer if shared.cmd_opts.no_half: model.float() + model.alphas_cumprod_original = model.alphas_cumprod devices.dtype_unet = torch.float32 timer.record("apply float()") else: @@ -384,7 +415,11 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer if shared.cmd_opts.upcast_sampling and depth_model: model.depth_model = None + alphas_cumprod = model.alphas_cumprod + model.alphas_cumprod = None model.half() + model.alphas_cumprod = alphas_cumprod + model.alphas_cumprod_original = alphas_cumprod model.first_stage_model = vae if depth_model: model.depth_model = depth_model @@ -392,6 +427,28 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer devices.dtype_unet = torch.float16 timer.record("apply half()") + for module in model.modules(): + if hasattr(module, 'fp16_weight'): + del module.fp16_weight + if hasattr(module, 'fp16_bias'): + del module.fp16_bias + + if check_fp8(model): + devices.fp8 = True + first_stage = model.first_stage_model + model.first_stage_model = None + for module in model.modules(): + if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)): + if shared.opts.cache_fp16_weight: + module.fp16_weight = module.weight.data.clone().cpu().half() + if module.bias is not None: + module.fp16_bias = module.bias.data.clone().cpu().half() + module.to(torch.float8_e4m3fn) + model.first_stage_model = first_stage + timer.record("apply fp8") + else: + devices.fp8 = False + devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16 model.first_stage_model.to(devices.dtype_vae) @@ -639,6 +696,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): else: weight_dtype_conversion = { 'first_stage_model': None, + 'alphas_cumprod': None, '': torch.float16, } @@ -734,7 +792,7 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer): return None -def reload_model_weights(sd_model=None, info=None): +def reload_model_weights(sd_model=None, info=None, forced_reload=False): checkpoint_info = info or select_checkpoint() timer = Timer() @@ -746,11 +804,14 @@ def reload_model_weights(sd_model=None, info=None): current_checkpoint_info = None else: current_checkpoint_info = sd_model.sd_checkpoint_info - if sd_model.sd_model_checkpoint == checkpoint_info.filename: + if check_fp8(sd_model) != devices.fp8: + # load from state dict again to prevent extra numerical errors + forced_reload = True + elif sd_model.sd_model_checkpoint == checkpoint_info.filename and not forced_reload: return sd_model sd_model = reuse_model_from_already_loaded(sd_model, checkpoint_info, timer) - if sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename: + if not forced_reload and sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename: return sd_model if sd_model is not None: diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py index deab2f6e2..b38137eb5 100644 --- a/modules/sd_models_config.py +++ b/modules/sd_models_config.py @@ -15,6 +15,7 @@ config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml") config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml") config_sdxl = os.path.join(sd_xl_repo_configs_path, "sd_xl_base.yaml") config_sdxl_refiner = os.path.join(sd_xl_repo_configs_path, "sd_xl_refiner.yaml") +config_sdxl_inpainting = os.path.join(sd_configs_path, "sd_xl_inpaint.yaml") config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml") config_unclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-l-inference.yaml") config_unopenclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-h-inference.yaml") @@ -71,7 +72,10 @@ def guess_model_config_from_state_dict(sd, filename): sd2_variations_weight = sd.get('embedder.model.ln_final.weight', None) if sd.get('conditioner.embedders.1.model.ln_final.weight', None) is not None: - return config_sdxl + if diffusion_model_input.shape[1] == 9: + return config_sdxl_inpainting + else: + return config_sdxl if sd.get('conditioner.embedders.0.model.ln_final.weight', None) is not None: return config_sdxl_refiner elif sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None: diff --git a/modules/sd_models_types.py b/modules/sd_models_types.py index 5ffd2f4f9..f911fbb68 100644 --- a/modules/sd_models_types.py +++ b/modules/sd_models_types.py @@ -22,7 +22,10 @@ class WebuiSdModel(LatentDiffusion): """structure with additional information about the file with model's weights""" is_sdxl: bool - """True if the model's architecture is SDXL""" + """True if the model's architecture is SDXL or SSD""" + + is_ssd: bool + """True if the model is SSD""" is_sd2: bool """True if the model's architecture is SD 2.x""" diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py index 011233216..0de17af3d 100644 --- a/modules/sd_models_xl.py +++ b/modules/sd_models_xl.py @@ -6,6 +6,7 @@ import sgm.models.diffusion import sgm.modules.diffusionmodules.denoiser_scaling import sgm.modules.diffusionmodules.discretizer from modules import devices, shared, prompt_parser +from modules import torch_utils def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: prompt_parser.SdConditioning | list[str]): @@ -34,6 +35,12 @@ def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond): + sd = self.model.state_dict() + diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None) + if diffusion_model_input is not None: + if diffusion_model_input.shape[1] == 9: + x = torch.cat([x] + cond['c_concat'], dim=1) + return self.model(x, t, cond) @@ -84,7 +91,7 @@ sgm.modules.GeneralConditioner.get_target_prompt_token_count = get_target_prompt def extend_sdxl(model): """this adds a bunch of parameters to make SDXL model look a bit more like SD1.5 to the rest of the codebase.""" - dtype = next(model.model.diffusion_model.parameters()).dtype + dtype = torch_utils.get_param(model.model.diffusion_model).dtype model.model.diffusion_model.dtype = dtype model.model.conditioning_key = 'crossattn' model.cond_stage_key = 'txt' @@ -93,7 +100,7 @@ def extend_sdxl(model): model.parameterization = "v" if isinstance(model.denoiser.scaling, sgm.modules.diffusionmodules.denoiser_scaling.VScaling) else "eps" discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization() - model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=dtype) + model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=torch.float32) model.conditioner.wrapped = torch.nn.Module() diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py index b8101d38d..eb9d5dafa 100644 --- a/modules/sd_samplers_cfg_denoiser.py +++ b/modules/sd_samplers_cfg_denoiser.py @@ -56,6 +56,9 @@ class CFGDenoiser(torch.nn.Module): self.sampler = sampler self.model_wrap = None self.p = None + + # NOTE: masking before denoising can cause the original latents to be oversmoothed + # as the original latents do not have noise self.mask_before_denoising = False @property @@ -105,8 +108,21 @@ class CFGDenoiser(torch.nn.Module): assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)" + # If we use masks, blending between the denoised and original latent images occurs here. + def apply_blend(current_latent): + blended_latent = current_latent * self.nmask + self.init_latent * self.mask + + if self.p.scripts is not None: + from modules import scripts + mba = scripts.MaskBlendArgs(current_latent, self.nmask, self.init_latent, self.mask, blended_latent, denoiser=self, sigma=sigma) + self.p.scripts.on_mask_blend(self.p, mba) + blended_latent = mba.blended_latent + + return blended_latent + + # Blend in the original latents (before) if self.mask_before_denoising and self.mask is not None: - x = self.init_latent * self.mask + self.nmask * x + x = apply_blend(x) batch_size = len(conds_list) repeats = [len(conds_list[i]) for i in range(batch_size)] @@ -207,8 +223,9 @@ class CFGDenoiser(torch.nn.Module): else: denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale) + # Blend in the original latents (after) if not self.mask_before_denoising and self.mask is not None: - denoised = self.init_latent * self.mask + self.nmask * denoised + denoised = apply_blend(denoised) self.sampler.last_latent = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([x_out[i:i + 1] for i in denoised_image_indexes]), sigma) diff --git a/modules/sd_samplers_extra.py b/modules/sd_samplers_extra.py index 1b981ca80..72fd0aa5e 100644 --- a/modules/sd_samplers_extra.py +++ b/modules/sd_samplers_extra.py @@ -60,7 +60,7 @@ def restart_sampler(model, x, sigmas, extra_args=None, callback=None, disable=No sigma_restart = get_sigmas_karras(restart_steps, sigmas[min_idx].item(), sigmas[max_idx].item(), device=sigmas.device)[:-1] while restart_times > 0: restart_times -= 1 - step_list.extend([(old_sigma, new_sigma) for (old_sigma, new_sigma) in zip(sigma_restart[:-1], sigma_restart[1:])]) + step_list.extend(zip(sigma_restart[:-1], sigma_restart[1:])) last_sigma = None for old_sigma, new_sigma in tqdm.tqdm(step_list, disable=disable): diff --git a/modules/sd_samplers_timesteps.py b/modules/sd_samplers_timesteps.py index b17a8f93c..777dd8d0e 100644 --- a/modules/sd_samplers_timesteps.py +++ b/modules/sd_samplers_timesteps.py @@ -36,7 +36,7 @@ class CompVisTimestepsVDenoiser(torch.nn.Module): self.inner_model = model def predict_eps_from_z_and_v(self, x_t, t, v): - return self.inner_model.sqrt_alphas_cumprod[t.to(torch.int), None, None, None] * v + self.inner_model.sqrt_one_minus_alphas_cumprod[t.to(torch.int), None, None, None] * x_t + return torch.sqrt(self.inner_model.alphas_cumprod)[t.to(torch.int), None, None, None] * v + torch.sqrt(1 - self.inner_model.alphas_cumprod)[t.to(torch.int), None, None, None] * x_t def forward(self, input, timesteps, **kwargs): model_output = self.inner_model.apply_model(input, timesteps, **kwargs) @@ -80,6 +80,7 @@ class CompVisSampler(sd_samplers_common.Sampler): self.eta_default = 0.0 self.model_wrap_cfg = CFGDenoiserTimesteps(self) + self.model_wrap = self.model_wrap_cfg.inner_model def get_timesteps(self, p, steps): discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False) diff --git a/modules/sd_samplers_timesteps_impl.py b/modules/sd_samplers_timesteps_impl.py index a72daafd4..930a64af5 100644 --- a/modules/sd_samplers_timesteps_impl.py +++ b/modules/sd_samplers_timesteps_impl.py @@ -11,7 +11,7 @@ from modules.models.diffusion.uni_pc import uni_pc def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=0.0): alphas_cumprod = model.inner_model.inner_model.alphas_cumprod alphas = alphas_cumprod[timesteps] - alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64 if x.device.type != 'mps' else torch.float32) + alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64 if x.device.type != 'mps' and x.device.type != 'xpu' else torch.float32) sqrt_one_minus_alphas = torch.sqrt(1 - alphas) sigmas = eta * np.sqrt((1 - alphas_prev.cpu().numpy()) / (1 - alphas.cpu()) * (1 - alphas.cpu() / alphas_prev.cpu().numpy())) @@ -43,7 +43,7 @@ def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta= def plms(model, x, timesteps, extra_args=None, callback=None, disable=None): alphas_cumprod = model.inner_model.inner_model.alphas_cumprod alphas = alphas_cumprod[timesteps] - alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64 if x.device.type != 'mps' else torch.float32) + alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64 if x.device.type != 'mps' and x.device.type != 'xpu' else torch.float32) sqrt_one_minus_alphas = torch.sqrt(1 - alphas) extra_args = {} if extra_args is None else extra_args diff --git a/modules/sd_unet.py b/modules/sd_unet.py index 6a7bc9e26..a771849c8 100644 --- a/modules/sd_unet.py +++ b/modules/sd_unet.py @@ -5,8 +5,7 @@ from modules import script_callbacks, shared, devices unet_options = [] current_unet_option = None current_unet = None -original_forward = None - +original_forward = None # not used, only left temporarily for compatibility def list_unets(): new_unets = script_callbacks.list_unets_callback() @@ -84,9 +83,12 @@ class SdUnet(torch.nn.Module): pass -def UNetModel_forward(self, x, timesteps=None, context=None, *args, **kwargs): - if current_unet is not None: - return current_unet.forward(x, timesteps, context, *args, **kwargs) +def create_unet_forward(original_forward): + def UNetModel_forward(self, x, timesteps=None, context=None, *args, **kwargs): + if current_unet is not None: + return current_unet.forward(x, timesteps, context, *args, **kwargs) - return original_forward(self, x, timesteps, context, *args, **kwargs) + return original_forward(self, x, timesteps, context, *args, **kwargs) + + return UNetModel_forward diff --git a/modules/shared_gradio_themes.py b/modules/shared_gradio_themes.py index 822db0a95..b6dc31450 100644 --- a/modules/shared_gradio_themes.py +++ b/modules/shared_gradio_themes.py @@ -65,3 +65,7 @@ def reload_gradio_theme(theme_name=None): except Exception as e: errors.display(e, "changing gradio theme") shared.gradio_theme = gr.themes.Default(**default_theme_args) + + # append additional values gradio_theme + shared.gradio_theme.sd_webui_modal_lightbox_toolbar_opacity = shared.opts.sd_webui_modal_lightbox_toolbar_opacity + shared.gradio_theme.sd_webui_modal_lightbox_icon_opacity = shared.opts.sd_webui_modal_lightbox_icon_opacity diff --git a/modules/shared_items.py b/modules/shared_items.py index b1459f8c4..e13924720 100644 --- a/modules/shared_items.py +++ b/modules/shared_items.py @@ -66,7 +66,25 @@ def reload_hypernetworks(): shared.hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir) +def get_infotext_names(): + from modules import infotext, shared + res = {} + + for info in shared.opts.data_labels.values(): + if info.infotext: + res[info.infotext] = 1 + + for tab_data in infotext.paste_fields.values(): + for _, name in tab_data.get("fields") or []: + if isinstance(name, str): + res[name] = 1 + + return list(res) + + ui_reorder_categories_builtin_items = [ + "prompt", + "image", "inpaint", "sampler", "accordions", diff --git a/modules/shared_options.py b/modules/shared_options.py index 0a82216ff..cca3f7be3 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -1,9 +1,10 @@ +import os import gradio as gr -from modules import localization, ui_components, shared_items, shared, interrogate, shared_gradio_themes -from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # noqa: F401 +from modules import localization, ui_components, shared_items, shared, interrogate, shared_gradio_themes, util +from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir, default_output_dir # noqa: F401 from modules.shared_cmd_options import cmd_opts -from modules.options import options_section, OptionInfo, OptionHTML +from modules.options import options_section, OptionInfo, OptionHTML, categories options_templates = {} hide_dirs = shared.hide_dirs @@ -21,7 +22,14 @@ restricted_opts = { "outdir_init_images" } -options_templates.update(options_section(('saving-images', "Saving images/grids"), { +categories.register_category("saving", "Saving images") +categories.register_category("sd", "Stable Diffusion") +categories.register_category("ui", "User Interface") +categories.register_category("system", "System") +categories.register_category("postprocessing", "Postprocessing") +categories.register_category("training", "Training") + +options_templates.update(options_section(('saving-images', "Saving images/grids", "saving"), { "samples_save": OptionInfo(True, "Always save all generated images"), "samples_format": OptionInfo('png', 'File format for images'), "samples_filename_pattern": OptionInfo("", "Images filename pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"), @@ -39,8 +47,6 @@ options_templates.update(options_section(('saving-images', "Saving images/grids" "grid_text_inactive_color": OptionInfo("#999999", "Inactive text color for image grids", ui_components.FormColorPicker, {}), "grid_background_color": OptionInfo("#ffffff", "Background color for image grids", ui_components.FormColorPicker, {}), - "enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"), - "save_txt": OptionInfo(False, "Create a text file next to every image with generation parameters."), "save_images_before_face_restoration": OptionInfo(False, "Save a copy of image before doing face restoration."), "save_images_before_highres_fix": OptionInfo(False, "Save a copy of image before applying highres fix."), "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"), @@ -64,21 +70,22 @@ options_templates.update(options_section(('saving-images', "Saving images/grids" "save_incomplete_images": OptionInfo(False, "Save incomplete images").info("save images that has been interrupted in mid-generation; even if not saved, they will still show up in webui output."), "notification_audio": OptionInfo(True, "Play notification sound after image generation").info("notification.mp3 should be present in the root directory").needs_reload_ui(), + "notification_volume": OptionInfo(100, "Notification sound volume", gr.Slider, {"minimum": 0, "maximum": 100, "step": 1}).info("in %"), })) -options_templates.update(options_section(('saving-paths', "Paths for saving"), { +options_templates.update(options_section(('saving-paths', "Paths for saving", "saving"), { "outdir_samples": OptionInfo("", "Output directory for images; if empty, defaults to three directories below", component_args=hide_dirs), - "outdir_txt2img_samples": OptionInfo("outputs/txt2img-images", 'Output directory for txt2img images', component_args=hide_dirs), - "outdir_img2img_samples": OptionInfo("outputs/img2img-images", 'Output directory for img2img images', component_args=hide_dirs), - "outdir_extras_samples": OptionInfo("outputs/extras-images", 'Output directory for images from extras tab', component_args=hide_dirs), + "outdir_txt2img_samples": OptionInfo(util.truncate_path(os.path.join(default_output_dir, 'txt2img-images')), 'Output directory for txt2img images', component_args=hide_dirs), + "outdir_img2img_samples": OptionInfo(util.truncate_path(os.path.join(default_output_dir, 'img2img-images')), 'Output directory for img2img images', component_args=hide_dirs), + "outdir_extras_samples": OptionInfo(util.truncate_path(os.path.join(default_output_dir, 'extras-images')), 'Output directory for images from extras tab', component_args=hide_dirs), "outdir_grids": OptionInfo("", "Output directory for grids; if empty, defaults to two directories below", component_args=hide_dirs), - "outdir_txt2img_grids": OptionInfo("outputs/txt2img-grids", 'Output directory for txt2img grids', component_args=hide_dirs), - "outdir_img2img_grids": OptionInfo("outputs/img2img-grids", 'Output directory for img2img grids', component_args=hide_dirs), - "outdir_save": OptionInfo("log/images", "Directory for saving images using the Save button", component_args=hide_dirs), - "outdir_init_images": OptionInfo("outputs/init-images", "Directory for saving init images when using img2img", component_args=hide_dirs), + "outdir_txt2img_grids": OptionInfo(util.truncate_path(os.path.join(default_output_dir, 'txt2img-grids')), 'Output directory for txt2img grids', component_args=hide_dirs), + "outdir_img2img_grids": OptionInfo(util.truncate_path(os.path.join(default_output_dir, 'img2img-grids')), 'Output directory for img2img grids', component_args=hide_dirs), + "outdir_save": OptionInfo(util.truncate_path(os.path.join(data_path, 'log', 'images')), "Directory for saving images using the Save button", component_args=hide_dirs), + "outdir_init_images": OptionInfo(util.truncate_path(os.path.join(default_output_dir, 'init-images')), "Directory for saving init images when using img2img", component_args=hide_dirs), })) -options_templates.update(options_section(('saving-to-dirs', "Saving to a directory"), { +options_templates.update(options_section(('saving-to-dirs', "Saving to a directory", "saving"), { "save_to_dirs": OptionInfo(True, "Save images to a subdirectory"), "grid_save_to_dirs": OptionInfo(True, "Save grids to a subdirectory"), "use_save_to_dirs_for_ui": OptionInfo(False, "When using \"Save\" button, save images to a subdirectory"), @@ -86,21 +93,21 @@ options_templates.update(options_section(('saving-to-dirs', "Saving to a directo "directories_max_prompt_words": OptionInfo(8, "Max prompt words for [prompt_words] pattern", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1, **hide_dirs}), })) -options_templates.update(options_section(('upscaling', "Upscaling"), { +options_templates.update(options_section(('upscaling', "Upscaling", "postprocessing"), { "ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}).info("0 = no tiling"), "ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}).info("Low values = visible seam"), "realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI.", gr.CheckboxGroup, lambda: {"choices": shared_items.realesrgan_models_names()}), "upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in shared.sd_upscalers]}), })) -options_templates.update(options_section(('face-restoration', "Face restoration"), { +options_templates.update(options_section(('face-restoration', "Face restoration", "postprocessing"), { "face_restoration": OptionInfo(False, "Restore faces", infotext='Face restoration').info("will use a third-party model on generation result to reconstruct faces"), "face_restoration_model": OptionInfo("CodeFormer", "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in shared.face_restorers]}), "code_former_weight": OptionInfo(0.5, "CodeFormer weight", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}).info("0 = maximum effect; 1 = minimum effect"), "face_restoration_unload": OptionInfo(False, "Move face restoration model from VRAM into RAM after processing"), })) -options_templates.update(options_section(('system', "System"), { +options_templates.update(options_section(('system', "System", "system"), { "auto_launch_browser": OptionInfo("Local", "Automatically open webui in browser on startup", gr.Radio, lambda: {"choices": ["Disable", "Local", "Remote"]}), "enable_console_prompts": OptionInfo(shared.cmd_opts.enable_console_prompts, "Print prompts to console when generating with txt2img and img2img."), "show_warnings": OptionInfo(False, "Show warnings in console.").needs_reload_ui(), @@ -115,13 +122,13 @@ options_templates.update(options_section(('system', "System"), { "dump_stacks_on_signal": OptionInfo(False, "Print stack traces before exiting the program with ctrl+c."), })) -options_templates.update(options_section(('API', "API"), { +options_templates.update(options_section(('API', "API", "system"), { "api_enable_requests": OptionInfo(True, "Allow http:// and https:// URLs for input images in API", restrict_api=True), "api_forbid_local_requests": OptionInfo(True, "Forbid URLs to local resources", restrict_api=True), "api_useragent": OptionInfo("", "User agent for requests", restrict_api=True), })) -options_templates.update(options_section(('training', "Training"), { +options_templates.update(options_section(('training', "Training", "training"), { "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."), "pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."), "save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training of embedding or HN can be resumed with the matching optim file."), @@ -136,7 +143,7 @@ options_templates.update(options_section(('training', "Training"), { "training_tensorboard_flush_every": OptionInfo(120, "How often, in seconds, to flush the pending tensorboard events and summaries to disk."), })) -options_templates.update(options_section(('sd', "Stable Diffusion"), { +options_templates.update(options_section(('sd', "Stable Diffusion", "sd"), { "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": shared_items.list_checkpoint_tiles(shared.opts.sd_checkpoint_dropdown_use_short)}, refresh=shared_items.refresh_checkpoints, infotext='Model hash'), "sd_checkpoints_limit": OptionInfo(1, "Maximum number of checkpoints loaded at the same time", gr.Slider, {"minimum": 1, "maximum": 10, "step": 1}), "sd_checkpoints_keep_in_cpu": OptionInfo(True, "Only keep one model on device").info("will keep models other than the currently used one in RAM rather than VRAM"), @@ -153,14 +160,14 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "hires_fix_refiner_pass": OptionInfo("second pass", "Hires fix: which pass to enable refiner for", gr.Radio, {"choices": ["first pass", "second pass", "both passes"]}, infotext="Hires refiner"), })) -options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), { +options_templates.update(options_section(('sdxl', "Stable Diffusion XL", "sd"), { "sdxl_crop_top": OptionInfo(0, "crop top coordinate"), "sdxl_crop_left": OptionInfo(0, "crop left coordinate"), "sdxl_refiner_low_aesthetic_score": OptionInfo(2.5, "SDXL low aesthetic score", gr.Number).info("used for refiner model negative prompt"), "sdxl_refiner_high_aesthetic_score": OptionInfo(6.0, "SDXL high aesthetic score", gr.Number).info("used for refiner model prompt"), })) -options_templates.update(options_section(('vae', "VAE"), { +options_templates.update(options_section(('vae', "VAE", "sd"), { "sd_vae_explanation": OptionHTML(""" VAE is a neural network that transforms a standard RGB image into latent space representation and back. Latent space representation is what stable diffusion is working on during sampling @@ -170,12 +177,13 @@ For img2img, VAE is used to process user's input image before the sampling, and "sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list, infotext='VAE').info("choose VAE model: Automatic = use one with same filename as checkpoint; None = use VAE from checkpoint"), "sd_vae_overrides_per_model_preferences": OptionInfo(True, "Selected VAE overrides per-model preferences").info("you can set per-model VAE either by editing user metadata for checkpoints, or by making the VAE have same name as checkpoint"), + "auto_vae_precision_bfloat16": OptionInfo(False, "Automatically convert VAE to bfloat16").info("triggers when a tensor with NaNs is produced in VAE; disabling the option in this case will result in a black square image; if enabled, overrides the option below"), "auto_vae_precision": OptionInfo(True, "Automatically revert VAE to 32-bit floats").info("triggers when a tensor with NaNs is produced in VAE; disabling the option in this case will result in a black square image"), "sd_vae_encode_method": OptionInfo("Full", "VAE type for encode", gr.Radio, {"choices": ["Full", "TAESD"]}, infotext='VAE Encoder').info("method to encode image to latent (use in img2img, hires-fix or inpaint mask)"), "sd_vae_decode_method": OptionInfo("Full", "VAE type for decode", gr.Radio, {"choices": ["Full", "TAESD"]}, infotext='VAE Decoder').info("method to decode latent to image"), })) -options_templates.update(options_section(('img2img', "img2img"), { +options_templates.update(options_section(('img2img', "img2img", "sd"), { "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Conditional mask weight'), "initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.0, "maximum": 1.5, "step": 0.001}, infotext='Noise multiplier'), "img2img_extra_noise": OptionInfo(0.0, "Extra noise multiplier for img2img and hires fix", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Extra noise').info("0 = disabled (default); should be lower than denoising strength"), @@ -188,9 +196,10 @@ options_templates.update(options_section(('img2img', "img2img"), { "img2img_inpaint_sketch_default_brush_color": OptionInfo("#ffffff", "Inpaint sketch initial brush color", ui_components.FormColorPicker, {}).info("default brush color of img2img inpaint sketch").needs_reload_ui(), "return_mask": OptionInfo(False, "For inpainting, include the greyscale mask in results for web"), "return_mask_composite": OptionInfo(False, "For inpainting, include masked composite in results for web"), + "img2img_batch_show_results_limit": OptionInfo(32, "Show the first N batch img2img results in UI", gr.Slider, {"minimum": -1, "maximum": 1000, "step": 1}).info('0: disable, -1: show all images. Too many images can cause lag'), })) -options_templates.update(options_section(('optimizations', "Optimizations"), { +options_templates.update(options_section(('optimizations', "Optimizations", "sd"), { "cross_attention_optimization": OptionInfo("Automatic", "Cross attention optimization", gr.Dropdown, lambda: {"choices": shared_items.cross_attention_optimizations()}), "s_min_uncond": OptionInfo(0.0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 15.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"), "token_merging_ratio": OptionInfo(0.0, "Token merging ratio", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}, infotext='Token merging ratio').link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9256").info("0=disable, higher=faster"), @@ -199,9 +208,12 @@ options_templates.update(options_section(('optimizations', "Optimizations"), { "pad_cond_uncond": OptionInfo(False, "Pad prompt/negative prompt to be same length", infotext='Pad conds').info("improves performance when prompt and negative prompt have different lengths; changes seeds"), "persistent_cond_cache": OptionInfo(True, "Persistent cond cache").info("do not recalculate conds from prompts if prompts have not changed since previous calculation"), "batch_cond_uncond": OptionInfo(True, "Batch cond/uncond").info("do both conditional and unconditional denoising in one batch; uses a bit more VRAM during sampling, but improves speed; previously this was controlled by --always-batch-cond-uncond comandline argument"), + "fp8_storage": OptionInfo("Disable", "FP8 weight", gr.Radio, {"choices": ["Disable", "Enable for SDXL", "Enable"]}).info("Use FP8 to store Linear/Conv layers' weight. Require pytorch>=2.1.0."), + "cache_fp16_weight": OptionInfo(False, "Cache FP16 weight for LoRA").info("Cache fp16 weight when enabling FP8, will increase the quality of LoRA. Use more system ram."), })) -options_templates.update(options_section(('compatibility', "Compatibility"), { +options_templates.update(options_section(('compatibility', "Compatibility", "sd"), { + "auto_backcompat": OptionInfo(True, "Automatic backward compatibility").info("automatically enable options for backwards compatibility when importing generation parameters from infotext that has program version."), "use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."), "use_old_karras_scheduler_sigmas": OptionInfo(False, "Use old karras scheduler sigmas (0.1 to 10)."), "no_dpmpp_sde_batch_determinism": OptionInfo(False, "Do not make DPM++ SDE deterministic across different batch sizes."), @@ -209,6 +221,7 @@ options_templates.update(options_section(('compatibility', "Compatibility"), { "dont_fix_second_order_samplers_schedule": OptionInfo(False, "Do not fix prompt schedule for second order samplers."), "hires_fix_use_firstpass_conds": OptionInfo(False, "For hires fix, calculate conds of second pass using extra networks of first pass."), "use_old_scheduling": OptionInfo(False, "Use old prompt editing timelines.", infotext="Old prompt editing timelines").info("For [red:green:N]; old: If N < 1, it's a fraction of steps (and hires fix uses range from 0 to 1), if N >= 1, it's an absolute number of steps; new: If N has a decimal point in it, it's a fraction of steps (and hires fix uses range from 1 to 2), othewrwise it's an absolute number of steps"), + "use_downcasted_alpha_bar": OptionInfo(False, "Downcast model alphas_cumprod to fp16 before sampling. For reproducing old seeds.", infotext="Downcast alphas_cumprod") })) options_templates.update(options_section(('interrogate', "Interrogate"), { @@ -226,14 +239,17 @@ options_templates.update(options_section(('interrogate', "Interrogate"), { "deepbooru_filter_tags": OptionInfo("", "deepbooru: filter out those tags").info("separate by comma"), })) -options_templates.update(options_section(('extra_networks', "Extra Networks"), { +options_templates.update(options_section(('extra_networks', "Extra Networks", "sd"), { "extra_networks_show_hidden_directories": OptionInfo(True, "Show hidden directories").info("directory is hidden if its name starts with \".\"."), + "extra_networks_dir_button_function": OptionInfo(False, "Add a '/' to the beginning of directory buttons").info("Buttons will display the contents of the selected directory without acting as a search filter."), "extra_networks_hidden_models": OptionInfo("When searched", "Show cards for models in hidden directories", gr.Radio, {"choices": ["Always", "When searched", "Never"]}).info('"When searched" option will only show the item when the search string has 4 characters or more'), "extra_networks_default_multiplier": OptionInfo(1.0, "Default multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}), "extra_networks_card_width": OptionInfo(0, "Card width for Extra Networks").info("in pixels"), "extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks").info("in pixels"), "extra_networks_card_text_scale": OptionInfo(1.0, "Card text scale", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}).info("1 = original size"), "extra_networks_card_show_desc": OptionInfo(True, "Show description on card"), + "extra_networks_card_order_field": OptionInfo("Path", "Default order field for Extra Networks cards", gr.Dropdown, {"choices": ['Path', 'Name', 'Date Created', 'Date Modified']}).needs_reload_ui(), + "extra_networks_card_order": OptionInfo("Ascending", "Default order for Extra Networks cards", gr.Dropdown, {"choices": ['Ascending', 'Descending']}).needs_reload_ui(), "extra_networks_add_text_separator": OptionInfo(" ", "Extra networks separator").info("extra text to add before <...> when adding extra network to prompt"), "ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order").needs_reload_ui(), "textual_inversion_print_at_load": OptionInfo(False, "Print a list of Textual Inversion embeddings when loading model"), @@ -241,44 +257,69 @@ options_templates.update(options_section(('extra_networks', "Extra Networks"), { "sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": ["None", *shared.hypernetworks]}, refresh=shared_items.reload_hypernetworks), })) -options_templates.update(options_section(('ui', "User interface"), { - "localization": OptionInfo("None", "Localization", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)).needs_reload_ui(), - "gradio_theme": OptionInfo("Default", "Gradio theme", ui_components.DropdownEditable, lambda: {"choices": ["Default"] + shared_gradio_themes.gradio_hf_hub_themes}).info("you can also manually enter any of themes from the gallery.").needs_reload_ui(), - "gradio_themes_cache": OptionInfo(True, "Cache gradio themes locally").info("disable to update the selected Gradio theme"), - "gallery_height": OptionInfo("", "Gallery height", gr.Textbox).info("an be any valid CSS value").needs_reload_ui(), - "return_grid": OptionInfo(True, "Show grid in results for web"), - "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"), - "send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"), - "send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"), - "js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"), - "js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"), - "js_modal_lightbox_gamepad": OptionInfo(False, "Navigate image viewer with gamepad"), - "js_modal_lightbox_gamepad_repeat": OptionInfo(250, "Gamepad repeat period, in milliseconds"), - "show_progress_in_title": OptionInfo(True, "Show generation progress in window title."), - "samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group").needs_reload_ui(), - "dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row").needs_reload_ui(), - "keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}), - "keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing ", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}), - "keyedit_delimiters": OptionInfo(r".,\/!?%^*;:{}=`~() ", "Ctrl+up/down word delimiters"), +options_templates.update(options_section(('ui_prompt_editing', "Prompt editing", "ui"), { + "keyedit_precision_attention": OptionInfo(0.1, "Precision for (attention:1.1) when editing the prompt with Ctrl+up/down", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}), + "keyedit_precision_extra": OptionInfo(0.05, "Precision for when editing the prompt with Ctrl+up/down", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}), + "keyedit_delimiters": OptionInfo(r".,\/!?%^*;:{}=`~() ", "Word delimiters when editing the prompt with Ctrl+up/down"), "keyedit_delimiters_whitespace": OptionInfo(["Tab", "Carriage Return", "Line Feed"], "Ctrl+up/down whitespace delimiters", gr.CheckboxGroup, lambda: {"choices": ["Tab", "Carriage Return", "Line Feed"]}), "keyedit_move": OptionInfo(True, "Alt+left/right moves prompt elements"), - "quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that appear at the top of page rather than in settings tab").needs_reload_ui(), - "ui_tab_order": OptionInfo([], "UI tab order", ui_components.DropdownMulti, lambda: {"choices": list(shared.tab_names)}).needs_reload_ui(), - "hidden_tabs": OptionInfo([], "Hidden UI tabs", ui_components.DropdownMulti, lambda: {"choices": list(shared.tab_names)}).needs_reload_ui(), - "ui_reorder_list": OptionInfo([], "txt2img/img2img UI item order", ui_components.DropdownMulti, lambda: {"choices": list(shared_items.ui_reorder_categories())}).info("selected items appear first").needs_reload_ui(), - "sd_checkpoint_dropdown_use_short": OptionInfo(False, "Checkpoint dropdown: use filenames without paths").info("models in subdirectories like photo/sd15.ckpt will be listed as just sd15.ckpt"), - "hires_fix_show_sampler": OptionInfo(False, "Hires fix: show hires checkpoint and sampler selection").needs_reload_ui(), - "hires_fix_show_prompts": OptionInfo(False, "Hires fix: show hires prompt and negative prompt").needs_reload_ui(), "disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_reload_ui(), })) +options_templates.update(options_section(('ui_gallery', "Gallery", "ui"), { + "return_grid": OptionInfo(True, "Show grid in gallery"), + "do_not_show_images": OptionInfo(False, "Do not show any images in gallery"), + "js_modal_lightbox": OptionInfo(True, "Full page image viewer: enable"), + "js_modal_lightbox_initially_zoomed": OptionInfo(True, "Full page image viewer: show images zoomed in by default"), + "js_modal_lightbox_gamepad": OptionInfo(False, "Full page image viewer: navigate with gamepad"), + "js_modal_lightbox_gamepad_repeat": OptionInfo(250, "Full page image viewer: gamepad repeat period").info("in milliseconds"), + "sd_webui_modal_lightbox_icon_opacity": OptionInfo(1, "Full page image viewer: control icon unfocused opacity", gr.Slider, {"minimum": 0.0, "maximum": 1, "step": 0.01}, onchange=shared.reload_gradio_theme).info('for mouse only').needs_reload_ui(), + "sd_webui_modal_lightbox_toolbar_opacity": OptionInfo(0.9, "Full page image viewer: tool bar opacity", gr.Slider, {"minimum": 0.0, "maximum": 1, "step": 0.01}, onchange=shared.reload_gradio_theme).info('for mouse only').needs_reload_ui(), + "gallery_height": OptionInfo("", "Gallery height", gr.Textbox).info("can be any valid CSS value, for example 768px or 20em").needs_reload_ui(), +})) -options_templates.update(options_section(('infotext', "Infotext"), { - "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"), - "add_model_name_to_info": OptionInfo(True, "Add model name to generation information"), - "add_user_name_to_info": OptionInfo(False, "Add user name to generation information when authenticated"), - "add_version_to_infotext": OptionInfo(True, "Add program version to generation information"), +options_templates.update(options_section(('ui_alternatives', "UI alternatives", "ui"), { + "compact_prompt_box": OptionInfo(False, "Compact prompt layout").info("puts prompt and negative prompt inside the Generate tab, leaving more vertical space for the image on the right").needs_reload_ui(), + "samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group").needs_reload_ui(), + "dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row").needs_reload_ui(), + "sd_checkpoint_dropdown_use_short": OptionInfo(False, "Checkpoint dropdown: use filenames without paths").info("models in subdirectories like photo/sd15.ckpt will be listed as just sd15.ckpt"), + "hires_fix_show_sampler": OptionInfo(False, "Hires fix: show hires checkpoint and sampler selection").needs_reload_ui(), + "hires_fix_show_prompts": OptionInfo(False, "Hires fix: show hires prompt and negative prompt").needs_reload_ui(), + "txt2img_settings_accordion": OptionInfo(False, "Settings in txt2img hidden under Accordion").needs_reload_ui(), + "img2img_settings_accordion": OptionInfo(False, "Settings in img2img hidden under Accordion").needs_reload_ui(), + "interrupt_after_current": OptionInfo(True, "Don't Interrupt in the middle").info("when using Interrupt button, if generating more than one image, stop after the generation of an image has finished, instead of immediately"), +})) + +options_templates.update(options_section(('ui', "User interface", "ui"), { + "localization": OptionInfo("None", "Localization", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)).needs_reload_ui(), + "quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that appear at the top of page rather than in settings tab").needs_reload_ui(), + "ui_tab_order": OptionInfo([], "UI tab order", ui_components.DropdownMulti, lambda: {"choices": list(shared.tab_names)}).needs_reload_ui(), + "hidden_tabs": OptionInfo([], "Hidden UI tabs", ui_components.DropdownMulti, lambda: {"choices": list(shared.tab_names)}).needs_reload_ui(), + "ui_reorder_list": OptionInfo([], "UI item order for txt2img/img2img tabs", ui_components.DropdownMulti, lambda: {"choices": list(shared_items.ui_reorder_categories())}).info("selected items appear first").needs_reload_ui(), + "gradio_theme": OptionInfo("Default", "Gradio theme", ui_components.DropdownEditable, lambda: {"choices": ["Default"] + shared_gradio_themes.gradio_hf_hub_themes}).info("you can also manually enter any of themes from the gallery.").needs_reload_ui(), + "gradio_themes_cache": OptionInfo(True, "Cache gradio themes locally").info("disable to update the selected Gradio theme"), + "show_progress_in_title": OptionInfo(True, "Show generation progress in window title."), + "send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"), + "send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"), +})) + + +options_templates.update(options_section(('infotext', "Infotext", "ui"), { + "infotext_explanation": OptionHTML(""" +Infotext is what this software calls the text that contains generation parameters and can be used to generate the same picture again. +It is displayed in UI below the image. To use infotext, paste it into the prompt and click the ↙️ paste button. +"""), + "enable_pnginfo": OptionInfo(True, "Write infotext to metadata of the generated image"), + "save_txt": OptionInfo(False, "Create a text file with infotext next to every generated image"), + + "add_model_name_to_info": OptionInfo(True, "Add model name to infotext"), + "add_model_hash_to_info": OptionInfo(True, "Add model hash to infotext"), + "add_vae_name_to_info": OptionInfo(True, "Add VAE name to infotext"), + "add_vae_hash_to_info": OptionInfo(True, "Add VAE hash to infotext"), + "add_user_name_to_info": OptionInfo(False, "Add user name to infotext when authenticated"), + "add_version_to_infotext": OptionInfo(True, "Add program version to infotext"), "disable_weights_auto_swap": OptionInfo(True, "Disregard checkpoint information from pasted infotext").info("when reading generation parameters from text into UI"), + "infotext_skip_pasting": OptionInfo([], "Disregard fields from pasted infotext", ui_components.DropdownMulti, lambda: {"choices": shared_items.get_infotext_names()}), "infotext_styles": OptionInfo("Apply if any", "Infer styles from prompts of pasted infotext", gr.Radio, {"choices": ["Ignore", "Apply", "Discard", "Apply if any"]}).info("when reading generation parameters from text into UI)").html("""
  • Ignore: keep prompt and styles dropdown as it is.
  • Apply: remove style text from prompt, always replace styles dropdown value with found styles (even if none are found).
  • @@ -288,7 +329,7 @@ options_templates.update(options_section(('infotext', "Infotext"), { })) -options_templates.update(options_section(('ui', "Live previews"), { +options_templates.update(options_section(('ui', "Live previews", "ui"), { "show_progressbar": OptionInfo(True, "Show progressbar"), "live_previews_enable": OptionInfo(True, "Show live previews of the created image"), "live_previews_image_format": OptionInfo("png", "Live preview file format", gr.Radio, {"choices": ["jpeg", "png", "webp"]}), @@ -299,9 +340,10 @@ options_templates.update(options_section(('ui', "Live previews"), { "live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}), "live_preview_refresh_period": OptionInfo(1000, "Progressbar and preview update period").info("in milliseconds"), "live_preview_fast_interrupt": OptionInfo(False, "Return image with chosen live preview method on interrupt").info("makes interrupts faster"), + "js_live_preview_in_modal_lightbox": OptionInfo(False, "Show Live preview in full page image viewer"), })) -options_templates.update(options_section(('sampler-params', "Sampler parameters"), { +options_templates.update(options_section(('sampler-params', "Sampler parameters", "sd"), { "hide_samplers": OptionInfo([], "Hide samplers in user interface", gr.CheckboxGroup, lambda: {"choices": [x.name for x in shared_items.list_samplers()]}).needs_reload_ui(), "eta_ddim": OptionInfo(0.0, "Eta for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Eta DDIM').info("noise multiplier; higher = more unpredictable results"), "eta_ancestral": OptionInfo(1.0, "Eta for k-diffusion samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Eta').info("noise multiplier; currently only applies to ancestral samplers (i.e. Euler a) and SDE samplers"), @@ -321,12 +363,14 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters" 'uni_pc_skip_type': OptionInfo("time_uniform", "UniPC skip type", gr.Radio, {"choices": ["time_uniform", "time_quadratic", "logSNR"]}, infotext='UniPC skip type'), 'uni_pc_order': OptionInfo(3, "UniPC order", gr.Slider, {"minimum": 1, "maximum": 50, "step": 1}, infotext='UniPC order').info("must be < sampling steps"), 'uni_pc_lower_order_final': OptionInfo(True, "UniPC lower order final", infotext='UniPC lower order final'), + 'sd_noise_schedule': OptionInfo("Default", "Noise schedule for sampling", gr.Radio, {"choices": ["Default", "Zero Terminal SNR"]}, infotext="Noise Schedule").info("for use with zero terminal SNR trained models") })) -options_templates.update(options_section(('postprocessing', "Postprocessing"), { +options_templates.update(options_section(('postprocessing', "Postprocessing", "postprocessing"), { 'postprocessing_enable_in_main_ui': OptionInfo([], "Enable postprocessing operations in txt2img and img2img tabs", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}), 'postprocessing_operation_order': OptionInfo([], "Postprocessing operation order", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}), 'upscaling_max_images_in_cache': OptionInfo(5, "Maximum number of images in upscaling cache", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), + 'postprocessing_existing_caption_action': OptionInfo("Ignore", "Action for existing captions", gr.Radio, {"choices": ["Ignore", "Keep", "Prepend", "Append"]}).info("when generating captions using postprocessing; Ignore = use generated; Keep = use original; Prepend/Append = combine both"), })) options_templates.update(options_section((None, "Hidden options"), { diff --git a/modules/shared_state.py b/modules/shared_state.py index a68789cc8..33996691c 100644 --- a/modules/shared_state.py +++ b/modules/shared_state.py @@ -12,6 +12,7 @@ log = logging.getLogger(__name__) class State: skipped = False interrupted = False + stopping_generation = False job = "" job_no = 0 job_count = 0 @@ -79,6 +80,10 @@ class State: self.interrupted = True log.info("Received interrupt request") + def stop_generating(self): + self.stopping_generation = True + log.info("Received stop generating request") + def nextjob(self): if shared.opts.live_previews_enable and shared.opts.show_progress_every_n_steps == -1: self.do_set_current_image() @@ -91,6 +96,7 @@ class State: obj = { "skipped": self.skipped, "interrupted": self.interrupted, + "stopping_generation": self.stopping_generation, "job": self.job, "job_count": self.job_count, "job_timestamp": self.job_timestamp, @@ -114,6 +120,7 @@ class State: self.id_live_preview = 0 self.skipped = False self.interrupted = False + self.stopping_generation = False self.textinfo = None self.job = job devices.torch_gc() diff --git a/modules/styles.py b/modules/styles.py index 0740fe1b1..026c43001 100644 --- a/modules/styles.py +++ b/modules/styles.py @@ -1,7 +1,7 @@ import csv +import fnmatch import os import os.path -import re import typing import shutil @@ -10,6 +10,7 @@ class PromptStyle(typing.NamedTuple): name: str prompt: str negative_prompt: str + path: str = None def merge_prompts(style_prompt: str, prompt: str) -> str: @@ -29,12 +30,17 @@ def apply_styles_to_prompt(prompt, styles): return prompt -re_spaces = re.compile(" +") - - def extract_style_text_from_prompt(style_text, prompt): - stripped_prompt = re.sub(re_spaces, " ", prompt.strip()) - stripped_style_text = re.sub(re_spaces, " ", style_text.strip()) + """This function extracts the text from a given prompt based on a provided style text. It checks if the style text contains the placeholder {prompt} or if it appears at the end of the prompt. If a match is found, it returns True along with the extracted text. Otherwise, it returns False and the original prompt. + + extract_style_text_from_prompt("masterpiece", "1girl, art by greg, masterpiece") outputs (True, "1girl, art by greg") + extract_style_text_from_prompt("masterpiece, {prompt}", "masterpiece, 1girl, art by greg") outputs (True, "1girl, art by greg") + extract_style_text_from_prompt("masterpiece, {prompt}", "exquisite, 1girl, art by greg") outputs (False, "exquisite, 1girl, art by greg") + """ + + stripped_prompt = prompt.strip() + stripped_style_text = style_text.strip() + if "{prompt}" in stripped_style_text: left, right = stripped_style_text.split("{prompt}", 2) if stripped_prompt.startswith(left) and stripped_prompt.endswith(right): @@ -52,7 +58,12 @@ def extract_style_text_from_prompt(style_text, prompt): return False, prompt -def extract_style_from_prompts(style: PromptStyle, prompt, negative_prompt): +def extract_original_prompts(style: PromptStyle, prompt, negative_prompt): + """ + Takes a style and compares it to the prompt and negative prompt. If the style + matches, returns True plus the prompt and negative prompt with the style text + removed. Otherwise, returns False with the original prompt and negative prompt. + """ if not style.prompt and not style.negative_prompt: return False, prompt, negative_prompt @@ -69,25 +80,84 @@ def extract_style_from_prompts(style: PromptStyle, prompt, negative_prompt): class StyleDatabase: def __init__(self, path: str): - self.no_style = PromptStyle("None", "", "") + self.no_style = PromptStyle("None", "", "", None) self.styles = {} self.path = path + folder, file = os.path.split(self.path) + filename, _, ext = file.partition('*') + self.default_path = os.path.join(folder, filename + ext) + + self.prompt_fields = [field for field in PromptStyle._fields if field != "path"] + self.reload() def reload(self): + """ + Clears the style database and reloads the styles from the CSV file(s) + matching the path used to initialize the database. + """ self.styles.clear() - if not os.path.exists(self.path): - return + path, filename = os.path.split(self.path) - with open(self.path, "r", encoding="utf-8-sig", newline='') as file: + if "*" in filename: + fileglob = filename.split("*")[0] + "*.csv" + filelist = [] + for file in os.listdir(path): + if fnmatch.fnmatch(file, fileglob): + filelist.append(file) + # Add a visible divider to the style list + half_len = round(len(file) / 2) + divider = f"{'-' * (20 - half_len)} {file.upper()}" + divider = f"{divider} {'-' * (40 - len(divider))}" + self.styles[divider] = PromptStyle( + f"{divider}", None, None, "do_not_save" + ) + # Add styles from this CSV file + self.load_from_csv(os.path.join(path, file)) + if len(filelist) == 0: + print(f"No styles found in {path} matching {fileglob}") + return + elif not os.path.exists(self.path): + print(f"Style database not found: {self.path}") + return + else: + self.load_from_csv(self.path) + + def load_from_csv(self, path: str): + with open(path, "r", encoding="utf-8-sig", newline="") as file: reader = csv.DictReader(file, skipinitialspace=True) for row in reader: + # Ignore empty rows or rows starting with a comment + if not row or row["name"].startswith("#"): + continue # Support loading old CSV format with "name, text"-columns prompt = row["prompt"] if "prompt" in row else row["text"] negative_prompt = row.get("negative_prompt", "") - self.styles[row["name"]] = PromptStyle(row["name"], prompt, negative_prompt) + # Add style to database + self.styles[row["name"]] = PromptStyle( + row["name"], prompt, negative_prompt, path + ) + + def get_style_paths(self) -> set: + """Returns a set of all distinct paths of files that styles are loaded from.""" + # Update any styles without a path to the default path + for style in list(self.styles.values()): + if not style.path: + self.styles[style.name] = style._replace(path=self.default_path) + + # Create a list of all distinct paths, including the default path + style_paths = set() + style_paths.add(self.default_path) + for _, style in self.styles.items(): + if style.path: + style_paths.add(style.path) + + # Remove any paths for styles that are just list dividers + style_paths.discard("do_not_save") + + return style_paths def get_style_prompts(self, styles): return [self.styles.get(x, self.no_style).prompt for x in styles] @@ -96,20 +166,40 @@ class StyleDatabase: return [self.styles.get(x, self.no_style).negative_prompt for x in styles] def apply_styles_to_prompt(self, prompt, styles): - return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).prompt for x in styles]) + return apply_styles_to_prompt( + prompt, [self.styles.get(x, self.no_style).prompt for x in styles] + ) def apply_negative_styles_to_prompt(self, prompt, styles): - return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).negative_prompt for x in styles]) + return apply_styles_to_prompt( + prompt, [self.styles.get(x, self.no_style).negative_prompt for x in styles] + ) - def save_styles(self, path: str) -> None: - # Always keep a backup file around - if os.path.exists(path): - shutil.copy(path, f"{path}.bak") + def save_styles(self, path: str = None) -> None: + # The path argument is deprecated, but kept for backwards compatibility + _ = path - with open(path, "w", encoding="utf-8-sig", newline='') as file: - writer = csv.DictWriter(file, fieldnames=PromptStyle._fields) - writer.writeheader() - writer.writerows(style._asdict() for k, style in self.styles.items()) + style_paths = self.get_style_paths() + + csv_names = [os.path.split(path)[1].lower() for path in style_paths] + + for style_path in style_paths: + # Always keep a backup file around + if os.path.exists(style_path): + shutil.copy(style_path, f"{style_path}.bak") + + # Write the styles to the CSV file + with open(style_path, "w", encoding="utf-8-sig", newline="") as file: + writer = csv.DictWriter(file, fieldnames=self.prompt_fields) + writer.writeheader() + for style in (s for s in self.styles.values() if s.path == style_path): + # Skip style list dividers, e.g. "STYLES.CSV" + if style.name.lower().strip("# ") in csv_names: + continue + # Write style fields, ignoring the path field + writer.writerow( + {k: v for k, v in style._asdict().items() if k != "path"} + ) def extract_styles_from_prompt(self, prompt, negative_prompt): extracted = [] @@ -120,7 +210,9 @@ class StyleDatabase: found_style = None for style in applicable_styles: - is_match, new_prompt, new_neg_prompt = extract_style_from_prompts(style, prompt, negative_prompt) + is_match, new_prompt, new_neg_prompt = extract_original_prompts( + style, prompt, negative_prompt + ) if is_match: found_style = style prompt = new_prompt diff --git a/modules/sysinfo.py b/modules/sysinfo.py index 2db7551dc..5abf616b7 100644 --- a/modules/sysinfo.py +++ b/modules/sysinfo.py @@ -1,7 +1,6 @@ import json import os import sys -import traceback import platform import hashlib @@ -27,11 +26,9 @@ environment_whitelist = { "OPENCLIP_PACKAGE", "STABLE_DIFFUSION_REPO", "K_DIFFUSION_REPO", - "CODEFORMER_REPO", "BLIP_REPO", "STABLE_DIFFUSION_COMMIT_HASH", "K_DIFFUSION_COMMIT_HASH", - "CODEFORMER_COMMIT_HASH", "BLIP_COMMIT_HASH", "COMMANDLINE_ARGS", "IGNORE_CMD_ARGS_ERRORS", @@ -84,7 +81,7 @@ def get_dict(): "Checksum": checksum_token, "Commandline": get_argv(), "Torch env info": get_torch_sysinfo(), - "Exceptions": get_exceptions(), + "Exceptions": errors.get_exceptions(), "CPU": { "model": platform.processor(), "count logical": psutil.cpu_count(logical=True), @@ -104,21 +101,6 @@ def get_dict(): return res -def format_traceback(tb): - return [[f"{x.filename}, line {x.lineno}, {x.name}", x.line] for x in traceback.extract_tb(tb)] - - -def format_exception(e, tb): - return {"exception": str(e), "traceback": format_traceback(tb)} - - -def get_exceptions(): - try: - return list(reversed(errors.exception_records)) - except Exception as e: - return str(e) - - def get_environment(): return {k: os.environ[k] for k in sorted(os.environ) if k in environment_whitelist} diff --git a/modules/textual_inversion/autocrop.py b/modules/textual_inversion/autocrop.py index 1675e39a5..e223a2e0c 100644 --- a/modules/textual_inversion/autocrop.py +++ b/modules/textual_inversion/autocrop.py @@ -3,6 +3,8 @@ import requests import os import numpy as np from PIL import ImageDraw +from modules import paths_internal +from pkg_resources import parse_version GREEN = "#0F0" BLUE = "#00F" @@ -25,7 +27,6 @@ def crop_image(im, settings): elif is_portrait(settings.crop_width, settings.crop_height): scale_by = settings.crop_height / im.height - im = im.resize((int(im.width * scale_by), int(im.height * scale_by))) im_debug = im.copy() @@ -69,6 +70,7 @@ def crop_image(im, settings): return results + def focal_point(im, settings): corner_points = image_corner_points(im, settings) if settings.corner_points_weight > 0 else [] entropy_points = image_entropy_points(im, settings) if settings.entropy_points_weight > 0 else [] @@ -78,118 +80,120 @@ def focal_point(im, settings): weight_pref_total = 0 if corner_points: - weight_pref_total += settings.corner_points_weight + weight_pref_total += settings.corner_points_weight if entropy_points: - weight_pref_total += settings.entropy_points_weight + weight_pref_total += settings.entropy_points_weight if face_points: - weight_pref_total += settings.face_points_weight + weight_pref_total += settings.face_points_weight corner_centroid = None if corner_points: - corner_centroid = centroid(corner_points) - corner_centroid.weight = settings.corner_points_weight / weight_pref_total - pois.append(corner_centroid) + corner_centroid = centroid(corner_points) + corner_centroid.weight = settings.corner_points_weight / weight_pref_total + pois.append(corner_centroid) entropy_centroid = None if entropy_points: - entropy_centroid = centroid(entropy_points) - entropy_centroid.weight = settings.entropy_points_weight / weight_pref_total - pois.append(entropy_centroid) + entropy_centroid = centroid(entropy_points) + entropy_centroid.weight = settings.entropy_points_weight / weight_pref_total + pois.append(entropy_centroid) face_centroid = None if face_points: - face_centroid = centroid(face_points) - face_centroid.weight = settings.face_points_weight / weight_pref_total - pois.append(face_centroid) + face_centroid = centroid(face_points) + face_centroid.weight = settings.face_points_weight / weight_pref_total + pois.append(face_centroid) average_point = poi_average(pois, settings) if settings.annotate_image: - d = ImageDraw.Draw(im) - max_size = min(im.width, im.height) * 0.07 - if corner_centroid is not None: - color = BLUE - box = corner_centroid.bounding(max_size * corner_centroid.weight) - d.text((box[0], box[1]-15), f"Edge: {corner_centroid.weight:.02f}", fill=color) - d.ellipse(box, outline=color) - if len(corner_points) > 1: - for f in corner_points: - d.rectangle(f.bounding(4), outline=color) - if entropy_centroid is not None: - color = "#ff0" - box = entropy_centroid.bounding(max_size * entropy_centroid.weight) - d.text((box[0], box[1]-15), f"Entropy: {entropy_centroid.weight:.02f}", fill=color) - d.ellipse(box, outline=color) - if len(entropy_points) > 1: - for f in entropy_points: - d.rectangle(f.bounding(4), outline=color) - if face_centroid is not None: - color = RED - box = face_centroid.bounding(max_size * face_centroid.weight) - d.text((box[0], box[1]-15), f"Face: {face_centroid.weight:.02f}", fill=color) - d.ellipse(box, outline=color) - if len(face_points) > 1: - for f in face_points: - d.rectangle(f.bounding(4), outline=color) + d = ImageDraw.Draw(im) + max_size = min(im.width, im.height) * 0.07 + if corner_centroid is not None: + color = BLUE + box = corner_centroid.bounding(max_size * corner_centroid.weight) + d.text((box[0], box[1] - 15), f"Edge: {corner_centroid.weight:.02f}", fill=color) + d.ellipse(box, outline=color) + if len(corner_points) > 1: + for f in corner_points: + d.rectangle(f.bounding(4), outline=color) + if entropy_centroid is not None: + color = "#ff0" + box = entropy_centroid.bounding(max_size * entropy_centroid.weight) + d.text((box[0], box[1] - 15), f"Entropy: {entropy_centroid.weight:.02f}", fill=color) + d.ellipse(box, outline=color) + if len(entropy_points) > 1: + for f in entropy_points: + d.rectangle(f.bounding(4), outline=color) + if face_centroid is not None: + color = RED + box = face_centroid.bounding(max_size * face_centroid.weight) + d.text((box[0], box[1] - 15), f"Face: {face_centroid.weight:.02f}", fill=color) + d.ellipse(box, outline=color) + if len(face_points) > 1: + for f in face_points: + d.rectangle(f.bounding(4), outline=color) - d.ellipse(average_point.bounding(max_size), outline=GREEN) + d.ellipse(average_point.bounding(max_size), outline=GREEN) return average_point def image_face_points(im, settings): if settings.dnn_model_path is not None: - detector = cv2.FaceDetectorYN.create( - settings.dnn_model_path, - "", - (im.width, im.height), - 0.9, # score threshold - 0.3, # nms threshold - 5000 # keep top k before nms - ) - faces = detector.detect(np.array(im)) - results = [] - if faces[1] is not None: - for face in faces[1]: - x = face[0] - y = face[1] - w = face[2] - h = face[3] - results.append( - PointOfInterest( - int(x + (w * 0.5)), # face focus left/right is center - int(y + (h * 0.33)), # face focus up/down is close to the top of the head - size = w, - weight = 1/len(faces[1]) - ) - ) - return results + detector = cv2.FaceDetectorYN.create( + settings.dnn_model_path, + "", + (im.width, im.height), + 0.9, # score threshold + 0.3, # nms threshold + 5000 # keep top k before nms + ) + faces = detector.detect(np.array(im)) + results = [] + if faces[1] is not None: + for face in faces[1]: + x = face[0] + y = face[1] + w = face[2] + h = face[3] + results.append( + PointOfInterest( + int(x + (w * 0.5)), # face focus left/right is center + int(y + (h * 0.33)), # face focus up/down is close to the top of the head + size=w, + weight=1 / len(faces[1]) + ) + ) + return results else: - np_im = np.array(im) - gray = cv2.cvtColor(np_im, cv2.COLOR_BGR2GRAY) + np_im = np.array(im) + gray = cv2.cvtColor(np_im, cv2.COLOR_BGR2GRAY) - tries = [ - [ f'{cv2.data.haarcascades}haarcascade_eye.xml', 0.01 ], - [ f'{cv2.data.haarcascades}haarcascade_frontalface_default.xml', 0.05 ], - [ f'{cv2.data.haarcascades}haarcascade_profileface.xml', 0.05 ], - [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt.xml', 0.05 ], - [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt2.xml', 0.05 ], - [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt_tree.xml', 0.05 ], - [ f'{cv2.data.haarcascades}haarcascade_eye_tree_eyeglasses.xml', 0.05 ], - [ f'{cv2.data.haarcascades}haarcascade_upperbody.xml', 0.05 ] - ] - for t in tries: - classifier = cv2.CascadeClassifier(t[0]) - minsize = int(min(im.width, im.height) * t[1]) # at least N percent of the smallest side - try: - faces = classifier.detectMultiScale(gray, scaleFactor=1.1, - minNeighbors=7, minSize=(minsize, minsize), flags=cv2.CASCADE_SCALE_IMAGE) - except Exception: - continue + tries = [ + [f'{cv2.data.haarcascades}haarcascade_eye.xml', 0.01], + [f'{cv2.data.haarcascades}haarcascade_frontalface_default.xml', 0.05], + [f'{cv2.data.haarcascades}haarcascade_profileface.xml', 0.05], + [f'{cv2.data.haarcascades}haarcascade_frontalface_alt.xml', 0.05], + [f'{cv2.data.haarcascades}haarcascade_frontalface_alt2.xml', 0.05], + [f'{cv2.data.haarcascades}haarcascade_frontalface_alt_tree.xml', 0.05], + [f'{cv2.data.haarcascades}haarcascade_eye_tree_eyeglasses.xml', 0.05], + [f'{cv2.data.haarcascades}haarcascade_upperbody.xml', 0.05] + ] + for t in tries: + classifier = cv2.CascadeClassifier(t[0]) + minsize = int(min(im.width, im.height) * t[1]) # at least N percent of the smallest side + try: + faces = classifier.detectMultiScale(gray, scaleFactor=1.1, + minNeighbors=7, minSize=(minsize, minsize), + flags=cv2.CASCADE_SCALE_IMAGE) + except Exception: + continue - if faces: - rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces] - return [PointOfInterest((r[0] +r[2]) // 2, (r[1] + r[3]) // 2, size=abs(r[0]-r[2]), weight=1/len(rects)) for r in rects] + if faces: + rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces] + return [PointOfInterest((r[0] + r[2]) // 2, (r[1] + r[3]) // 2, size=abs(r[0] - r[2]), + weight=1 / len(rects)) for r in rects] return [] @@ -198,7 +202,7 @@ def image_corner_points(im, settings): # naive attempt at preventing focal points from collecting at watermarks near the bottom gd = ImageDraw.Draw(grayscale) - gd.rectangle([0, im.height*.9, im.width, im.height], fill="#999") + gd.rectangle([0, im.height * .9, im.width, im.height], fill="#999") np_im = np.array(grayscale) @@ -206,7 +210,7 @@ def image_corner_points(im, settings): np_im, maxCorners=100, qualityLevel=0.04, - minDistance=min(grayscale.width, grayscale.height)*0.06, + minDistance=min(grayscale.width, grayscale.height) * 0.06, useHarrisDetector=False, ) @@ -215,8 +219,8 @@ def image_corner_points(im, settings): focal_points = [] for point in points: - x, y = point.ravel() - focal_points.append(PointOfInterest(x, y, size=4, weight=1/len(points))) + x, y = point.ravel() + focal_points.append(PointOfInterest(x, y, size=4, weight=1 / len(points))) return focal_points @@ -225,13 +229,13 @@ def image_entropy_points(im, settings): landscape = im.height < im.width portrait = im.height > im.width if landscape: - move_idx = [0, 2] - move_max = im.size[0] + move_idx = [0, 2] + move_max = im.size[0] elif portrait: - move_idx = [1, 3] - move_max = im.size[1] + move_idx = [1, 3] + move_max = im.size[1] else: - return [] + return [] e_max = 0 crop_current = [0, 0, settings.crop_width, settings.crop_height] @@ -241,14 +245,14 @@ def image_entropy_points(im, settings): e = image_entropy(crop) if (e > e_max): - e_max = e - crop_best = list(crop_current) + e_max = e + crop_best = list(crop_current) crop_current[move_idx[0]] += 4 crop_current[move_idx[1]] += 4 - x_mid = int(crop_best[0] + settings.crop_width/2) - y_mid = int(crop_best[1] + settings.crop_height/2) + x_mid = int(crop_best[0] + settings.crop_width / 2) + y_mid = int(crop_best[1] + settings.crop_height / 2) return [PointOfInterest(x_mid, y_mid, size=25, weight=1.0)] @@ -294,22 +298,23 @@ def is_square(w, h): return w == h -def download_and_cache_models(dirname): - download_url = 'https://github.com/opencv/opencv_zoo/blob/91fb0290f50896f38a0ab1e558b74b16bc009428/models/face_detection_yunet/face_detection_yunet_2022mar.onnx?raw=true' - model_file_name = 'face_detection_yunet.onnx' +model_dir_opencv = os.path.join(paths_internal.models_path, 'opencv') +if parse_version(cv2.__version__) >= parse_version('4.8'): + model_file_path = os.path.join(model_dir_opencv, 'face_detection_yunet_2023mar.onnx') + model_url = 'https://github.com/opencv/opencv_zoo/blob/b6e370b10f641879a87890d44e42173077154a05/models/face_detection_yunet/face_detection_yunet_2023mar.onnx?raw=true' +else: + model_file_path = os.path.join(model_dir_opencv, 'face_detection_yunet.onnx') + model_url = 'https://github.com/opencv/opencv_zoo/blob/91fb0290f50896f38a0ab1e558b74b16bc009428/models/face_detection_yunet/face_detection_yunet_2022mar.onnx?raw=true' - os.makedirs(dirname, exist_ok=True) - cache_file = os.path.join(dirname, model_file_name) - if not os.path.exists(cache_file): - print(f"downloading face detection model from '{download_url}' to '{cache_file}'") - response = requests.get(download_url) - with open(cache_file, "wb") as f: +def download_and_cache_models(): + if not os.path.exists(model_file_path): + os.makedirs(model_dir_opencv, exist_ok=True) + print(f"downloading face detection model from '{model_url}' to '{model_file_path}'") + response = requests.get(model_url) + with open(model_file_path, "wb") as f: f.write(response.content) - - if os.path.exists(cache_file): - return cache_file - return None + return model_file_path class PointOfInterest: diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py deleted file mode 100644 index dbd856bd8..000000000 --- a/modules/textual_inversion/preprocess.py +++ /dev/null @@ -1,232 +0,0 @@ -import os -from PIL import Image, ImageOps -import math -import tqdm - -from modules import paths, shared, images, deepbooru -from modules.textual_inversion import autocrop - - -def preprocess(id_task, process_src, process_dst, process_width, process_height, preprocess_txt_action, process_keep_original_size, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.15, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None): - try: - if process_caption: - shared.interrogator.load() - - if process_caption_deepbooru: - deepbooru.model.start() - - preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_keep_original_size, process_flip, process_split, process_caption, process_caption_deepbooru, split_threshold, overlap_ratio, process_focal_crop, process_focal_crop_face_weight, process_focal_crop_entropy_weight, process_focal_crop_edges_weight, process_focal_crop_debug, process_multicrop, process_multicrop_mindim, process_multicrop_maxdim, process_multicrop_minarea, process_multicrop_maxarea, process_multicrop_objective, process_multicrop_threshold) - - finally: - - if process_caption: - shared.interrogator.send_blip_to_ram() - - if process_caption_deepbooru: - deepbooru.model.stop() - - -def listfiles(dirname): - return os.listdir(dirname) - - -class PreprocessParams: - src = None - dstdir = None - subindex = 0 - flip = False - process_caption = False - process_caption_deepbooru = False - preprocess_txt_action = None - - -def save_pic_with_caption(image, index, params: PreprocessParams, existing_caption=None): - caption = "" - - if params.process_caption: - caption += shared.interrogator.generate_caption(image) - - if params.process_caption_deepbooru: - if caption: - caption += ", " - caption += deepbooru.model.tag_multi(image) - - filename_part = params.src - filename_part = os.path.splitext(filename_part)[0] - filename_part = os.path.basename(filename_part) - - basename = f"{index:05}-{params.subindex}-{filename_part}" - image.save(os.path.join(params.dstdir, f"{basename}.png")) - - if params.preprocess_txt_action == 'prepend' and existing_caption: - caption = f"{existing_caption} {caption}" - elif params.preprocess_txt_action == 'append' and existing_caption: - caption = f"{caption} {existing_caption}" - elif params.preprocess_txt_action == 'copy' and existing_caption: - caption = existing_caption - - caption = caption.strip() - - if caption: - with open(os.path.join(params.dstdir, f"{basename}.txt"), "w", encoding="utf8") as file: - file.write(caption) - - params.subindex += 1 - - -def save_pic(image, index, params, existing_caption=None): - save_pic_with_caption(image, index, params, existing_caption=existing_caption) - - if params.flip: - save_pic_with_caption(ImageOps.mirror(image), index, params, existing_caption=existing_caption) - - -def split_pic(image, inverse_xy, width, height, overlap_ratio): - if inverse_xy: - from_w, from_h = image.height, image.width - to_w, to_h = height, width - else: - from_w, from_h = image.width, image.height - to_w, to_h = width, height - h = from_h * to_w // from_w - if inverse_xy: - image = image.resize((h, to_w)) - else: - image = image.resize((to_w, h)) - - split_count = math.ceil((h - to_h * overlap_ratio) / (to_h * (1.0 - overlap_ratio))) - y_step = (h - to_h) / (split_count - 1) - for i in range(split_count): - y = int(y_step * i) - if inverse_xy: - splitted = image.crop((y, 0, y + to_h, to_w)) - else: - splitted = image.crop((0, y, to_w, y + to_h)) - yield splitted - -# not using torchvision.transforms.CenterCrop because it doesn't allow float regions -def center_crop(image: Image, w: int, h: int): - iw, ih = image.size - if ih / h < iw / w: - sw = w * ih / h - box = (iw - sw) / 2, 0, iw - (iw - sw) / 2, ih - else: - sh = h * iw / w - box = 0, (ih - sh) / 2, iw, ih - (ih - sh) / 2 - return image.resize((w, h), Image.Resampling.LANCZOS, box) - - -def multicrop_pic(image: Image, mindim, maxdim, minarea, maxarea, objective, threshold): - iw, ih = image.size - err = lambda w, h: 1-(lambda x: x if x < 1 else 1/x)(iw/ih/(w/h)) - wh = max(((w, h) for w in range(mindim, maxdim+1, 64) for h in range(mindim, maxdim+1, 64) - if minarea <= w * h <= maxarea and err(w, h) <= threshold), - key= lambda wh: (wh[0]*wh[1], -err(*wh))[::1 if objective=='Maximize area' else -1], - default=None - ) - return wh and center_crop(image, *wh) - - -def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_keep_original_size, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None): - width = process_width - height = process_height - src = os.path.abspath(process_src) - dst = os.path.abspath(process_dst) - split_threshold = max(0.0, min(1.0, split_threshold)) - overlap_ratio = max(0.0, min(0.9, overlap_ratio)) - - assert src != dst, 'same directory specified as source and destination' - - os.makedirs(dst, exist_ok=True) - - files = listfiles(src) - - shared.state.job = "preprocess" - shared.state.textinfo = "Preprocessing..." - shared.state.job_count = len(files) - - params = PreprocessParams() - params.dstdir = dst - params.flip = process_flip - params.process_caption = process_caption - params.process_caption_deepbooru = process_caption_deepbooru - params.preprocess_txt_action = preprocess_txt_action - - pbar = tqdm.tqdm(files) - for index, imagefile in enumerate(pbar): - params.subindex = 0 - filename = os.path.join(src, imagefile) - try: - img = Image.open(filename) - img = ImageOps.exif_transpose(img) - img = img.convert("RGB") - except Exception: - continue - - description = f"Preprocessing [Image {index}/{len(files)}]" - pbar.set_description(description) - shared.state.textinfo = description - - params.src = filename - - existing_caption = None - existing_caption_filename = f"{os.path.splitext(filename)[0]}.txt" - if os.path.exists(existing_caption_filename): - with open(existing_caption_filename, 'r', encoding="utf8") as file: - existing_caption = file.read() - - if shared.state.interrupted: - break - - if img.height > img.width: - ratio = (img.width * height) / (img.height * width) - inverse_xy = False - else: - ratio = (img.height * width) / (img.width * height) - inverse_xy = True - - process_default_resize = True - - if process_split and ratio < 1.0 and ratio <= split_threshold: - for splitted in split_pic(img, inverse_xy, width, height, overlap_ratio): - save_pic(splitted, index, params, existing_caption=existing_caption) - process_default_resize = False - - if process_focal_crop and img.height != img.width: - - dnn_model_path = None - try: - dnn_model_path = autocrop.download_and_cache_models(os.path.join(paths.models_path, "opencv")) - except Exception as e: - print("Unable to load face detection model for auto crop selection. Falling back to lower quality haar method.", e) - - autocrop_settings = autocrop.Settings( - crop_width = width, - crop_height = height, - face_points_weight = process_focal_crop_face_weight, - entropy_points_weight = process_focal_crop_entropy_weight, - corner_points_weight = process_focal_crop_edges_weight, - annotate_image = process_focal_crop_debug, - dnn_model_path = dnn_model_path, - ) - for focal in autocrop.crop_image(img, autocrop_settings): - save_pic(focal, index, params, existing_caption=existing_caption) - process_default_resize = False - - if process_multicrop: - cropped = multicrop_pic(img, process_multicrop_mindim, process_multicrop_maxdim, process_multicrop_minarea, process_multicrop_maxarea, process_multicrop_objective, process_multicrop_threshold) - if cropped is not None: - save_pic(cropped, index, params, existing_caption=existing_caption) - else: - print(f"skipped {img.width}x{img.height} image {filename} (can't find suitable size within error threshold)") - process_default_resize = False - - if process_keep_original_size: - save_pic(img, index, params, existing_caption=existing_caption) - process_default_resize = False - - if process_default_resize: - img = images.resize_image(1, img, width, height) - save_pic(img, index, params, existing_caption=existing_caption) - - shared.state.nextjob() diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 04dda585c..c6bcab153 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -11,7 +11,6 @@ import safetensors.torch import numpy as np from PIL import Image, PngImagePlugin -from torch.utils.tensorboard import SummaryWriter from modules import shared, devices, sd_hijack, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors, hashes import modules.textual_inversion.dataset @@ -344,6 +343,7 @@ def write_loss(log_directory, filename, step, epoch_len, values): }) def tensorboard_setup(log_directory): + from torch.utils.tensorboard import SummaryWriter os.makedirs(os.path.join(log_directory, "tensorboard"), exist_ok=True) return SummaryWriter( log_dir=os.path.join(log_directory, "tensorboard"), @@ -448,8 +448,12 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." old_parallel_processing_allowed = shared.parallel_processing_allowed + tensorboard_writer = None if shared.opts.training_enable_tensorboard: - tensorboard_writer = tensorboard_setup(log_directory) + try: + tensorboard_writer = tensorboard_setup(log_directory) + except ImportError: + errors.report("Error initializing tensorboard", exc_info=True) pin_memory = shared.opts.pin_memory @@ -622,7 +626,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False) last_saved_image += f", prompt: {preview_text}" - if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images: + if tensorboard_writer and shared.opts.training_tensorboard_save_images: tensorboard_add_image(tensorboard_writer, f"Validation at epoch {epoch_num}", image, embedding.step) if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded: diff --git a/modules/textual_inversion/ui.py b/modules/textual_inversion/ui.py index 35c4feeff..f149ad1f0 100644 --- a/modules/textual_inversion/ui.py +++ b/modules/textual_inversion/ui.py @@ -3,7 +3,6 @@ import html import gradio as gr import modules.textual_inversion.textual_inversion -import modules.textual_inversion.preprocess from modules import sd_hijack, shared @@ -15,12 +14,6 @@ def create_embedding(name, initialization_text, nvpt, overwrite_old): return gr.Dropdown.update(choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())), f"Created: {filename}", "" -def preprocess(*args): - modules.textual_inversion.preprocess.preprocess(*args) - - return f"Preprocessing {'interrupted' if shared.state.interrupted else 'finished'}.", "" - - def train_embedding(*args): assert not shared.cmd_opts.lowvram, 'Training models with lowvram not possible' diff --git a/modules/torch_utils.py b/modules/torch_utils.py new file mode 100644 index 000000000..e5b52393e --- /dev/null +++ b/modules/torch_utils.py @@ -0,0 +1,17 @@ +from __future__ import annotations + +import torch.nn + + +def get_param(model) -> torch.nn.Parameter: + """ + Find the first parameter in a model or module. + """ + if hasattr(model, "model") and hasattr(model.model, "parameters"): + # Unpeel a model descriptor to get at the actual Torch module. + model = model.model + + for param in model.parameters(): + return param + + raise ValueError(f"No parameters found in model {model!r}") diff --git a/modules/txt2img.py b/modules/txt2img.py index e4e18ceb6..3a481915f 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -2,7 +2,7 @@ from contextlib import closing import modules.scripts from modules import processing -from modules.generation_parameters_copypaste import create_override_settings_dict +from modules.infotext import create_override_settings_dict from modules.shared import opts import modules.shared as shared from modules.ui import plaintext_to_html diff --git a/modules/ui.py b/modules/ui.py index bcf391997..378529c79 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -4,6 +4,7 @@ import os import sys from functools import reduce import warnings +from contextlib import ExitStack import gradio as gr import gradio.utils @@ -12,7 +13,7 @@ from PIL import Image, PngImagePlugin # noqa: F401 from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call from modules import gradio_extensons # noqa: F401 -from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, ui_prompt_styles, scripts, sd_samplers, processing, ui_extra_networks +from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, scripts, sd_samplers, processing, ui_extra_networks, ui_toprow from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML, InputAccordion, ResizeHandleRow from modules.paths import script_path from modules.ui_common import create_refresh_button @@ -20,15 +21,14 @@ from modules.ui_gradio_extensions import reload_javascript from modules.shared import opts, cmd_opts -import modules.generation_parameters_copypaste as parameters_copypaste +import modules.infotext as parameters_copypaste import modules.hypernetworks.ui as hypernetworks_ui import modules.textual_inversion.ui as textual_inversion_ui import modules.textual_inversion.textual_inversion as textual_inversion import modules.shared as shared -import modules.images from modules import prompt_parser from modules.sd_hijack import model_hijack -from modules.generation_parameters_copypaste import image_from_url_text +from modules.infotext import image_from_url_text, PasteField create_setting_component = ui_settings.create_setting_component @@ -177,80 +177,6 @@ def update_negative_prompt_token_counter(text, steps): return update_token_counter(text, steps, is_positive=False) -class Toprow: - """Creates a top row UI with prompts, generate button, styles, extra little buttons for things, and enables some functionality related to their operation""" - - def __init__(self, is_img2img): - id_part = "img2img" if is_img2img else "txt2img" - self.id_part = id_part - - with gr.Row(elem_id=f"{id_part}_toprow", variant="compact"): - with gr.Column(elem_id=f"{id_part}_prompt_container", scale=6): - with gr.Row(): - with gr.Column(scale=80): - with gr.Row(): - self.prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=3, placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"]) - self.prompt_img = gr.File(label="", elem_id=f"{id_part}_prompt_image", file_count="single", type="binary", visible=False) - - with gr.Row(): - with gr.Column(scale=80): - with gr.Row(): - self.negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"]) - - self.button_interrogate = None - self.button_deepbooru = None - if is_img2img: - with gr.Column(scale=1, elem_classes="interrogate-col"): - self.button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate") - self.button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru") - - with gr.Column(scale=1, elem_id=f"{id_part}_actions_column"): - with gr.Row(elem_id=f"{id_part}_generate_box", elem_classes="generate-box"): - self.interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt", elem_classes="generate-box-interrupt") - self.skip = gr.Button('Skip', elem_id=f"{id_part}_skip", elem_classes="generate-box-skip") - self.submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary') - - self.skip.click( - fn=lambda: shared.state.skip(), - inputs=[], - outputs=[], - ) - - self.interrupt.click( - fn=lambda: shared.state.interrupt(), - inputs=[], - outputs=[], - ) - - with gr.Row(elem_id=f"{id_part}_tools"): - self.paste = ToolButton(value=paste_symbol, elem_id="paste", tooltip="Read generation parameters from prompt or last generation if prompt is empty into user interface.") - self.clear_prompt_button = ToolButton(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt", tooltip="Clear prompt") - self.apply_styles = ToolButton(value=ui_prompt_styles.styles_materialize_symbol, elem_id=f"{id_part}_style_apply", tooltip="Apply all selected styles to prompts.") - self.restore_progress_button = ToolButton(value=restore_progress_symbol, elem_id=f"{id_part}_restore_progress", visible=False, tooltip="Restore progress") - - self.token_counter = gr.HTML(value="0/75", elem_id=f"{id_part}_token_counter", elem_classes=["token-counter"]) - self.token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button") - self.negative_token_counter = gr.HTML(value="0/75", elem_id=f"{id_part}_negative_token_counter", elem_classes=["token-counter"]) - self.negative_token_button = gr.Button(visible=False, elem_id=f"{id_part}_negative_token_button") - - self.clear_prompt_button.click( - fn=lambda *x: x, - _js="confirm_clear_prompt", - inputs=[self.prompt, self.negative_prompt], - outputs=[self.prompt, self.negative_prompt], - ) - - self.ui_styles = ui_prompt_styles.UiPromptStyles(id_part, self.prompt, self.negative_prompt) - self.ui_styles.setup_apply_button(self.apply_styles) - - self.prompt_img.change( - fn=modules.images.image_data, - inputs=[self.prompt_img], - outputs=[self.prompt, self.prompt_img], - show_progress=False, - ) - - def setup_progressbar(*args, **kwargs): pass @@ -288,8 +214,8 @@ def apply_setting(key, value): return getattr(opts, key) -def create_output_panel(tabname, outdir): - return ui_common.create_output_panel(tabname, outdir) +def create_output_panel(tabname, outdir, toprow=None): + return ui_common.create_output_panel(tabname, outdir, toprow) def create_sampler_and_steps_selection(choices, tabname): @@ -336,7 +262,7 @@ def create_ui(): scripts.scripts_txt2img.initialize_scripts(is_img2img=False) with gr.Blocks(analytics_enabled=False) as txt2img_interface: - toprow = Toprow(is_img2img=False) + toprow = ui_toprow.Toprow(is_img2img=False, is_compact=shared.opts.compact_prompt_box) dummy_component = gr.Label(visible=False) @@ -344,10 +270,17 @@ def create_ui(): extra_tabs.__enter__() with gr.Tab("Generation", id="txt2img_generation") as txt2img_generation_tab, ResizeHandleRow(equal_height=False): - with gr.Column(variant='compact', elem_id="txt2img_settings"): + with ExitStack() as stack: + if shared.opts.txt2img_settings_accordion: + stack.enter_context(gr.Accordion("Open for Settings", open=False)) + stack.enter_context(gr.Column(variant='compact', elem_id="txt2img_settings")) + scripts.scripts_txt2img.prepare_ui() for category in ordered_ui_categories(): + if category == "prompt": + toprow.create_inline_toprow_prompts() + if category == "sampler": steps, sampler_name = create_sampler_and_steps_selection(sd_samplers.visible_sampler_names(), "txt2img") @@ -442,7 +375,7 @@ def create_ui(): show_progress=False, ) - txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples) + txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples, toprow) txt2img_args = dict( fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']), @@ -502,28 +435,28 @@ def create_ui(): ) txt2img_paste_fields = [ - (toprow.prompt, "Prompt"), - (toprow.negative_prompt, "Negative prompt"), - (steps, "Steps"), - (sampler_name, "Sampler"), - (cfg_scale, "CFG scale"), - (width, "Size-1"), - (height, "Size-2"), - (batch_size, "Batch size"), - (toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()), - (denoising_strength, "Denoising strength"), - (enable_hr, lambda d: "Denoising strength" in d and ("Hires upscale" in d or "Hires upscaler" in d or "Hires resize-1" in d)), - (hr_scale, "Hires upscale"), - (hr_upscaler, "Hires upscaler"), - (hr_second_pass_steps, "Hires steps"), - (hr_resize_x, "Hires resize-1"), - (hr_resize_y, "Hires resize-2"), - (hr_checkpoint_name, "Hires checkpoint"), - (hr_sampler_name, "Hires sampler"), - (hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" or d.get("Hires checkpoint", "Use same checkpoint") != "Use same checkpoint" else gr.update()), - (hr_prompt, "Hires prompt"), - (hr_negative_prompt, "Hires negative prompt"), - (hr_prompts_container, lambda d: gr.update(visible=True) if d.get("Hires prompt", "") != "" or d.get("Hires negative prompt", "") != "" else gr.update()), + PasteField(toprow.prompt, "Prompt", api="prompt"), + PasteField(toprow.negative_prompt, "Negative prompt", api="negative_prompt"), + PasteField(steps, "Steps", api="steps"), + PasteField(sampler_name, "Sampler", api="sampler_name"), + PasteField(cfg_scale, "CFG scale", api="cfg_scale"), + PasteField(width, "Size-1", api="width"), + PasteField(height, "Size-2", api="height"), + PasteField(batch_size, "Batch size", api="batch_size"), + PasteField(toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update(), api="styles"), + PasteField(denoising_strength, "Denoising strength", api="denoising_strength"), + PasteField(enable_hr, lambda d: "Denoising strength" in d and ("Hires upscale" in d or "Hires upscaler" in d or "Hires resize-1" in d), api="enable_hr"), + PasteField(hr_scale, "Hires upscale", api="hr_scale"), + PasteField(hr_upscaler, "Hires upscaler", api="hr_upscaler"), + PasteField(hr_second_pass_steps, "Hires steps", api="hr_second_pass_steps"), + PasteField(hr_resize_x, "Hires resize-1", api="hr_resize_x"), + PasteField(hr_resize_y, "Hires resize-2", api="hr_resize_y"), + PasteField(hr_checkpoint_name, "Hires checkpoint", api="hr_checkpoint_name"), + PasteField(hr_sampler_name, "Hires sampler", api="hr_sampler_name"), + PasteField(hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" or d.get("Hires checkpoint", "Use same checkpoint") != "Use same checkpoint" else gr.update()), + PasteField(hr_prompt, "Hires prompt", api="hr_prompt"), + PasteField(hr_negative_prompt, "Hires negative prompt", api="hr_negative_prompt"), + PasteField(hr_prompts_container, lambda d: gr.update(visible=True) if d.get("Hires prompt", "") != "" or d.get("Hires negative prompt", "") != "" else gr.update()), *scripts.scripts_txt2img.infotext_fields ] parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields, override_settings) @@ -554,13 +487,17 @@ def create_ui(): scripts.scripts_img2img.initialize_scripts(is_img2img=True) with gr.Blocks(analytics_enabled=False) as img2img_interface: - toprow = Toprow(is_img2img=True) + toprow = ui_toprow.Toprow(is_img2img=True, is_compact=shared.opts.compact_prompt_box) extra_tabs = gr.Tabs(elem_id="img2img_extra_tabs") extra_tabs.__enter__() with gr.Tab("Generation", id="img2img_generation") as img2img_generation_tab, ResizeHandleRow(equal_height=False): - with gr.Column(variant='compact', elem_id="img2img_settings"): + with ExitStack() as stack: + if shared.opts.img2img_settings_accordion: + stack.enter_context(gr.Accordion("Open for Settings", open=False)) + stack.enter_context(gr.Column(variant='compact', elem_id="img2img_settings")) + copy_image_buttons = [] copy_image_destinations = {} @@ -577,85 +514,89 @@ def create_ui(): button = gr.Button(title) copy_image_buttons.append((button, name, elem)) - with gr.Tabs(elem_id="mode_img2img"): - img2img_selected_tab = gr.State(0) - - with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab") as tab_img2img: - init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA", height=opts.img2img_editor_height) - add_copy_image_controls('img2img', init_img) - - with gr.TabItem('Sketch', id='img2img_sketch', elem_id="img2img_img2img_sketch_tab") as tab_sketch: - sketch = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGB", height=opts.img2img_editor_height, brush_color=opts.img2img_sketch_default_brush_color) - add_copy_image_controls('sketch', sketch) - - with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab") as tab_inpaint: - init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_mask_brush_color) - add_copy_image_controls('inpaint', init_img_with_mask) - - with gr.TabItem('Inpaint sketch', id='inpaint_sketch', elem_id="img2img_inpaint_sketch_tab") as tab_inpaint_color: - inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGB", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_sketch_default_brush_color) - inpaint_color_sketch_orig = gr.State(None) - add_copy_image_controls('inpaint_sketch', inpaint_color_sketch) - - def update_orig(image, state): - if image is not None: - same_size = state is not None and state.size == image.size - has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1)) - edited = same_size and has_exact_match - return image if not edited or state is None else state - - inpaint_color_sketch.change(update_orig, [inpaint_color_sketch, inpaint_color_sketch_orig], inpaint_color_sketch_orig) - - with gr.TabItem('Inpaint upload', id='inpaint_upload', elem_id="img2img_inpaint_upload_tab") as tab_inpaint_upload: - init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", elem_id="img_inpaint_base") - init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", image_mode="RGBA", elem_id="img_inpaint_mask") - - with gr.TabItem('Batch', id='batch', elem_id="img2img_batch_tab") as tab_batch: - hidden = '
    Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else '' - gr.HTML( - "

    Process images in a directory on the same machine where the server is running." + - "
    Use an empty output directory to save pictures normally instead of writing to the output directory." + - f"
    Add inpaint batch mask directory to enable inpaint batch processing." - f"{hidden}

    " - ) - img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir") - img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir") - img2img_batch_inpaint_mask_dir = gr.Textbox(label="Inpaint batch mask directory (required for inpaint batch processing only)", **shared.hide_dirs, elem_id="img2img_batch_inpaint_mask_dir") - with gr.Accordion("PNG info", open=False): - img2img_batch_use_png_info = gr.Checkbox(label="Append png info to prompts", **shared.hide_dirs, elem_id="img2img_batch_use_png_info") - img2img_batch_png_info_dir = gr.Textbox(label="PNG info directory", **shared.hide_dirs, placeholder="Leave empty to use input directory", elem_id="img2img_batch_png_info_dir") - img2img_batch_png_info_props = gr.CheckboxGroup(["Prompt", "Negative prompt", "Seed", "CFG scale", "Sampler", "Steps", "Model hash"], label="Parameters to take from png info", info="Prompts from png info will be appended to prompts set in ui.") - - img2img_tabs = [tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch] - - for i, tab in enumerate(img2img_tabs): - tab.select(fn=lambda tabnum=i: tabnum, inputs=[], outputs=[img2img_selected_tab]) - - def copy_image(img): - if isinstance(img, dict) and 'image' in img: - return img['image'] - - return img - - for button, name, elem in copy_image_buttons: - button.click( - fn=copy_image, - inputs=[elem], - outputs=[copy_image_destinations[name]], - ) - button.click( - fn=lambda: None, - _js=f"switch_to_{name.replace(' ', '_')}", - inputs=[], - outputs=[], - ) - - with FormRow(): - resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize") - scripts.scripts_img2img.prepare_ui() for category in ordered_ui_categories(): + if category == "prompt": + toprow.create_inline_toprow_prompts() + + if category == "image": + with gr.Tabs(elem_id="mode_img2img"): + img2img_selected_tab = gr.State(0) + + with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab") as tab_img2img: + init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA", height=opts.img2img_editor_height) + add_copy_image_controls('img2img', init_img) + + with gr.TabItem('Sketch', id='img2img_sketch', elem_id="img2img_img2img_sketch_tab") as tab_sketch: + sketch = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGB", height=opts.img2img_editor_height, brush_color=opts.img2img_sketch_default_brush_color) + add_copy_image_controls('sketch', sketch) + + with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab") as tab_inpaint: + init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_mask_brush_color) + add_copy_image_controls('inpaint', init_img_with_mask) + + with gr.TabItem('Inpaint sketch', id='inpaint_sketch', elem_id="img2img_inpaint_sketch_tab") as tab_inpaint_color: + inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGB", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_sketch_default_brush_color) + inpaint_color_sketch_orig = gr.State(None) + add_copy_image_controls('inpaint_sketch', inpaint_color_sketch) + + def update_orig(image, state): + if image is not None: + same_size = state is not None and state.size == image.size + has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1)) + edited = same_size and has_exact_match + return image if not edited or state is None else state + + inpaint_color_sketch.change(update_orig, [inpaint_color_sketch, inpaint_color_sketch_orig], inpaint_color_sketch_orig) + + with gr.TabItem('Inpaint upload', id='inpaint_upload', elem_id="img2img_inpaint_upload_tab") as tab_inpaint_upload: + init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", elem_id="img_inpaint_base") + init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", image_mode="RGBA", elem_id="img_inpaint_mask") + + with gr.TabItem('Batch', id='batch', elem_id="img2img_batch_tab") as tab_batch: + hidden = '
    Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else '' + gr.HTML( + "

    Process images in a directory on the same machine where the server is running." + + "
    Use an empty output directory to save pictures normally instead of writing to the output directory." + + f"
    Add inpaint batch mask directory to enable inpaint batch processing." + f"{hidden}

    " + ) + img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir") + img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir") + img2img_batch_inpaint_mask_dir = gr.Textbox(label="Inpaint batch mask directory (required for inpaint batch processing only)", **shared.hide_dirs, elem_id="img2img_batch_inpaint_mask_dir") + with gr.Accordion("PNG info", open=False): + img2img_batch_use_png_info = gr.Checkbox(label="Append png info to prompts", **shared.hide_dirs, elem_id="img2img_batch_use_png_info") + img2img_batch_png_info_dir = gr.Textbox(label="PNG info directory", **shared.hide_dirs, placeholder="Leave empty to use input directory", elem_id="img2img_batch_png_info_dir") + img2img_batch_png_info_props = gr.CheckboxGroup(["Prompt", "Negative prompt", "Seed", "CFG scale", "Sampler", "Steps", "Model hash"], label="Parameters to take from png info", info="Prompts from png info will be appended to prompts set in ui.") + + img2img_tabs = [tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch] + + for i, tab in enumerate(img2img_tabs): + tab.select(fn=lambda tabnum=i: tabnum, inputs=[], outputs=[img2img_selected_tab]) + + def copy_image(img): + if isinstance(img, dict) and 'image' in img: + return img['image'] + + return img + + for button, name, elem in copy_image_buttons: + button.click( + fn=copy_image, + inputs=[elem], + outputs=[copy_image_destinations[name]], + ) + button.click( + fn=lambda: None, + _js=f"switch_to_{name.replace(' ', '_')}", + inputs=[], + outputs=[], + ) + + with FormRow(): + resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize") + if category == "sampler": steps, sampler_name = create_sampler_and_steps_selection(sd_samplers.visible_sampler_names(), "img2img") @@ -693,12 +634,6 @@ def create_ui(): scale_by.release(**on_change_args) button_update_resize_to.click(**on_change_args) - # the code below is meant to update the resolution label after the image in the image selection UI has changed. - # as it is now the event keeps firing continuously for inpaint edits, which ruins the page with constant requests. - # I assume this must be a gradio bug and for now we'll just do it for non-inpaint inputs. - for component in [init_img, sketch]: - component.change(fn=lambda: None, _js="updateImg2imgResizeToTextAfterChangingImage", inputs=[], outputs=[], show_progress=False) - tab_scale_to.select(fn=lambda: 0, inputs=[], outputs=[selected_scale_tab]) tab_scale_by.select(fn=lambda: 1, inputs=[], outputs=[selected_scale_tab]) @@ -756,20 +691,26 @@ def create_ui(): with gr.Column(scale=4): inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding") - def select_img2img_tab(tab): - return gr.update(visible=tab in [2, 3, 4]), gr.update(visible=tab == 3), - - for i, elem in enumerate(img2img_tabs): - elem.select( - fn=lambda tab=i: select_img2img_tab(tab), - inputs=[], - outputs=[inpaint_controls, mask_alpha], - ) - if category not in {"accordions"}: scripts.scripts_img2img.setup_ui_for_section(category) - img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples) + # the code below is meant to update the resolution label after the image in the image selection UI has changed. + # as it is now the event keeps firing continuously for inpaint edits, which ruins the page with constant requests. + # I assume this must be a gradio bug and for now we'll just do it for non-inpaint inputs. + for component in [init_img, sketch]: + component.change(fn=lambda: None, _js="updateImg2imgResizeToTextAfterChangingImage", inputs=[], outputs=[], show_progress=False) + + def select_img2img_tab(tab): + return gr.update(visible=tab in [2, 3, 4]), gr.update(visible=tab == 3), + + for i, elem in enumerate(img2img_tabs): + elem.select( + fn=lambda tab=i: select_img2img_tab(tab), + inputs=[], + outputs=[inpaint_controls, mask_alpha], + ) + + img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples, toprow) img2img_args = dict( fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']), @@ -970,71 +911,6 @@ def create_ui(): with gr.Column(): create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary', elem_id="train_create_hypernetwork") - with gr.Tab(label="Preprocess images", id="preprocess_images"): - process_src = gr.Textbox(label='Source directory', elem_id="train_process_src") - process_dst = gr.Textbox(label='Destination directory', elem_id="train_process_dst") - process_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_process_width") - process_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_process_height") - preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"], elem_id="train_preprocess_txt_action") - - with gr.Row(): - process_keep_original_size = gr.Checkbox(label='Keep original size', elem_id="train_process_keep_original_size") - process_flip = gr.Checkbox(label='Create flipped copies', elem_id="train_process_flip") - process_split = gr.Checkbox(label='Split oversized images', elem_id="train_process_split") - process_focal_crop = gr.Checkbox(label='Auto focal point crop', elem_id="train_process_focal_crop") - process_multicrop = gr.Checkbox(label='Auto-sized crop', elem_id="train_process_multicrop") - process_caption = gr.Checkbox(label='Use BLIP for caption', elem_id="train_process_caption") - process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True, elem_id="train_process_caption_deepbooru") - - with gr.Row(visible=False) as process_split_extra_row: - process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_split_threshold") - process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05, elem_id="train_process_overlap_ratio") - - with gr.Row(visible=False) as process_focal_crop_row: - process_focal_crop_face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_face_weight") - process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_entropy_weight") - process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_edges_weight") - process_focal_crop_debug = gr.Checkbox(label='Create debug image', elem_id="train_process_focal_crop_debug") - - with gr.Column(visible=False) as process_multicrop_col: - gr.Markdown('Each image is center-cropped with an automatically chosen width and height.') - with gr.Row(): - process_multicrop_mindim = gr.Slider(minimum=64, maximum=2048, step=8, label="Dimension lower bound", value=384, elem_id="train_process_multicrop_mindim") - process_multicrop_maxdim = gr.Slider(minimum=64, maximum=2048, step=8, label="Dimension upper bound", value=768, elem_id="train_process_multicrop_maxdim") - with gr.Row(): - process_multicrop_minarea = gr.Slider(minimum=64*64, maximum=2048*2048, step=1, label="Area lower bound", value=64*64, elem_id="train_process_multicrop_minarea") - process_multicrop_maxarea = gr.Slider(minimum=64*64, maximum=2048*2048, step=1, label="Area upper bound", value=640*640, elem_id="train_process_multicrop_maxarea") - with gr.Row(): - process_multicrop_objective = gr.Radio(["Maximize area", "Minimize error"], value="Maximize area", label="Resizing objective", elem_id="train_process_multicrop_objective") - process_multicrop_threshold = gr.Slider(minimum=0, maximum=1, step=0.01, label="Error threshold", value=0.1, elem_id="train_process_multicrop_threshold") - - with gr.Row(): - with gr.Column(scale=3): - gr.HTML(value="") - - with gr.Column(): - with gr.Row(): - interrupt_preprocessing = gr.Button("Interrupt", elem_id="train_interrupt_preprocessing") - run_preprocess = gr.Button(value="Preprocess", variant='primary', elem_id="train_run_preprocess") - - process_split.change( - fn=lambda show: gr_show(show), - inputs=[process_split], - outputs=[process_split_extra_row], - ) - - process_focal_crop.change( - fn=lambda show: gr_show(show), - inputs=[process_focal_crop], - outputs=[process_focal_crop_row], - ) - - process_multicrop.change( - fn=lambda show: gr_show(show), - inputs=[process_multicrop], - outputs=[process_multicrop_col], - ) - def get_textual_inversion_template_names(): return sorted(textual_inversion.textual_inversion_templates) @@ -1135,42 +1011,6 @@ def create_ui(): ] ) - run_preprocess.click( - fn=wrap_gradio_gpu_call(textual_inversion_ui.preprocess, extra_outputs=[gr.update()]), - _js="start_training_textual_inversion", - inputs=[ - dummy_component, - process_src, - process_dst, - process_width, - process_height, - preprocess_txt_action, - process_keep_original_size, - process_flip, - process_split, - process_caption, - process_caption_deepbooru, - process_split_threshold, - process_overlap_ratio, - process_focal_crop, - process_focal_crop_face_weight, - process_focal_crop_entropy_weight, - process_focal_crop_edges_weight, - process_focal_crop_debug, - process_multicrop, - process_multicrop_mindim, - process_multicrop_maxdim, - process_multicrop_minarea, - process_multicrop_maxarea, - process_multicrop_objective, - process_multicrop_threshold, - ], - outputs=[ - ti_output, - ti_outcome, - ], - ) - train_embedding.click( fn=wrap_gradio_gpu_call(textual_inversion_ui.train_embedding, extra_outputs=[gr.update()]), _js="start_training_textual_inversion", @@ -1244,13 +1084,8 @@ def create_ui(): outputs=[], ) - interrupt_preprocessing.click( - fn=lambda: shared.state.interrupt(), - inputs=[], - outputs=[], - ) - loadsave = ui_loadsave.UiLoadsave(cmd_opts.ui_config_file) + ui_settings_from_file = loadsave.ui_settings.copy() settings = ui_settings.UiSettings() settings.create_ui(loadsave, dummy_component) @@ -1311,7 +1146,8 @@ def create_ui(): modelmerger_ui.setup_ui(dummy_component=dummy_component, sd_model_checkpoint_component=settings.component_dict['sd_model_checkpoint']) - loadsave.dump_defaults() + if ui_settings_from_file != loadsave.ui_settings: + loadsave.dump_defaults() demo.ui_loadsave = loadsave return demo @@ -1366,7 +1202,7 @@ def setup_ui_api(app): from fastapi.responses import PlainTextResponse text = sysinfo.get() - filename = f"sysinfo-{datetime.datetime.utcnow().strftime('%Y-%m-%d-%H-%M')}.txt" + filename = f"sysinfo-{datetime.datetime.utcnow().strftime('%Y-%m-%d-%H-%M')}.json" return PlainTextResponse(text, headers={'Content-Disposition': f'{"attachment" if attachment else "inline"}; filename="{filename}"'}) diff --git a/modules/ui_common.py b/modules/ui_common.py index 84a7d7f27..fd32676f9 100644 --- a/modules/ui_common.py +++ b/modules/ui_common.py @@ -8,10 +8,10 @@ import gradio as gr import subprocess as sp from modules import call_queue, shared -from modules.generation_parameters_copypaste import image_from_url_text +from modules.infotext import image_from_url_text import modules.images from modules.ui_components import ToolButton -import modules.generation_parameters_copypaste as parameters_copypaste +import modules.infotext as parameters_copypaste folder_symbol = '\U0001f4c2' # 📂 refresh_symbol = '\U0001f504' # 🔄 @@ -104,7 +104,7 @@ def save_files(js_data, images, do_make_zip, index): return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}") -def create_output_panel(tabname, outdir): +def create_output_panel(tabname, outdir, toprow=None): def open_folder(f): if not os.path.exists(f): @@ -130,12 +130,15 @@ Requested path was: {f} else: sp.Popen(["xdg-open", path]) - with gr.Column(variant='panel', elem_id=f"{tabname}_results"): - with gr.Group(elem_id=f"{tabname}_gallery_container"): - result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery", columns=4, preview=True, height=shared.opts.gallery_height or None) + with gr.Column(elem_id=f"{tabname}_results"): + if toprow: + toprow.create_inline_toprow_image() - generation_info = None - with gr.Column(): + with gr.Column(variant='panel', elem_id=f"{tabname}_results_panel"): + with gr.Group(elem_id=f"{tabname}_gallery_container"): + result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery", columns=4, preview=True, height=shared.opts.gallery_height or None) + + generation_info = None with gr.Row(elem_id=f"image_buttons_{tabname}", elem_classes="image-buttons"): open_folder_button = ToolButton(folder_symbol, elem_id=f'{tabname}_open_folder', visible=not shared.cmd_opts.hide_ui_dir_config, tooltip="Open images output directory.") diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index c0a73b573..dc1e34c8a 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -65,7 +65,7 @@ def save_config_state(name): filename = os.path.join(config_states_dir, f"{timestamp}_{name}.json") print(f"Saving backup of webui/extension state to {filename}.") with open(filename, "w", encoding="utf-8") as f: - json.dump(current_config_state, f, indent=4) + json.dump(current_config_state, f, indent=4, ensure_ascii=False) config_states.list_config_states() new_value = next(iter(config_states.all_config_states.keys()), "Current") new_choices = ["Current"] + list(config_states.all_config_states.keys()) @@ -335,6 +335,11 @@ def normalize_git_url(url): return url +def get_extension_dirname_from_url(url): + *parts, last_part = url.split('/') + return normalize_git_url(last_part) + + def install_extension_from_url(dirname, url, branch_name=None): check_access() @@ -346,10 +351,7 @@ def install_extension_from_url(dirname, url, branch_name=None): assert url, 'No URL specified' if dirname is None or dirname == "": - *parts, last_part = url.split('/') - last_part = normalize_git_url(last_part) - - dirname = last_part + dirname = get_extension_dirname_from_url(url) target_dir = os.path.join(extensions.extensions_dir, dirname) assert not os.path.exists(target_dir), f'Extension directory already exists: {target_dir}' @@ -449,7 +451,8 @@ def get_date(info: dict, key): def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text=""): extlist = available_extensions["extensions"] - installed_extension_urls = {normalize_git_url(extension.remote): extension.name for extension in extensions.extensions} + installed_extensions = {extension.name for extension in extensions.extensions} + installed_extension_urls = {normalize_git_url(extension.remote) for extension in extensions.extensions if extension.remote is not None} tags = available_extensions.get("tags", {}) tags_to_hide = set(hide_tags) @@ -482,7 +485,7 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text=" if url is None: continue - existing = installed_extension_urls.get(normalize_git_url(url), None) + existing = get_extension_dirname_from_url(url) in installed_extensions or normalize_git_url(url) in installed_extension_urls extension_tags = extension_tags + ["installed"] if existing else extension_tags if any(x for x in extension_tags if x in tags_to_hide): diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 59d6ecc61..790af1356 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -10,7 +10,7 @@ import json import html from fastapi.exceptions import HTTPException -from modules.generation_parameters_copypaste import image_from_url_text +from modules.infotext import image_from_url_text from modules.ui_components import ToolButton extra_pages = [] @@ -103,6 +103,7 @@ class ExtraNetworksPage: self.name = title.lower() self.id_page = self.name.replace(" ", "_") self.card_page = shared.html("extra-networks-card.html") + self.allow_prompt = True self.allow_negative_prompt = False self.metadata = {} self.items = {} @@ -150,8 +151,13 @@ class ExtraNetworksPage: continue subdir = os.path.abspath(x)[len(parentdir):].replace("\\", "/") - while subdir.startswith("/"): - subdir = subdir[1:] + + if shared.opts.extra_networks_dir_button_function: + if not subdir.startswith("/"): + subdir = "/" + subdir + else: + while subdir.startswith("/"): + subdir = subdir[1:] is_empty = len(os.listdir(x)) == 0 if not is_empty and not subdir.endswith("/"): @@ -217,7 +223,10 @@ class ExtraNetworksPage: onclick = item.get("onclick", None) if onclick is None: - onclick = '"' + html.escape(f"""return cardClicked({quote_js(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"' + if "negative_prompt" in item: + onclick = '"' + html.escape(f"""return cardClicked({quote_js(tabname)}, {item["prompt"]}, {item["negative_prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"' + else: + onclick = '"' + html.escape(f"""return cardClicked({quote_js(tabname)}, {item["prompt"]}, {'""'}, {"true" if self.allow_negative_prompt else "false"})""") + '"' height = f"height: {shared.opts.extra_networks_card_height}px;" if shared.opts.extra_networks_card_height else '' width = f"width: {shared.opts.extra_networks_card_width}px;" if shared.opts.extra_networks_card_width else '' @@ -278,6 +287,7 @@ class ExtraNetworksPage: "date_created": int(stat.st_ctime or 0), "date_modified": int(stat.st_mtime or 0), "name": pth.name.lower(), + "path": str(pth.parent).lower(), } def find_preview(self, path): @@ -367,7 +377,10 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): related_tabs = [] for page in ui.stored_extra_pages: - with gr.Tab(page.title, id=page.id_page) as tab: + with gr.Tab(page.title, elem_id=f"{tabname}_{page.id_page}", elem_classes=["extra-page"]) as tab: + with gr.Column(elem_id=f"{tabname}_{page.id_page}_prompts", elem_classes=["extra-page-prompts"]): + pass + elem_id = f"{tabname}_{page.id_page}_cards_html" page_elem = gr.HTML('Loading...', elem_id=elem_id) ui.pages.append(page_elem) @@ -381,19 +394,28 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): related_tabs.append(tab) edit_search = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", elem_classes="search", placeholder="Search...", visible=False, interactive=True) - dropdown_sort = gr.Dropdown(choices=['Default Sort', 'Date Created', 'Date Modified', 'Name'], value='Default Sort', elem_id=tabname+"_extra_sort", elem_classes="sort", multiselect=False, visible=False, show_label=False, interactive=True, label=tabname+"_extra_sort_order") - button_sortorder = ToolButton(switch_values_symbol, elem_id=tabname+"_extra_sortorder", elem_classes="sortorder", visible=False, tooltip="Invert sort order") + dropdown_sort = gr.Dropdown(choices=['Path', 'Name', 'Date Created', 'Date Modified', ], value=shared.opts.extra_networks_card_order_field, elem_id=tabname+"_extra_sort", elem_classes="sort", multiselect=False, visible=False, show_label=False, interactive=True, label=tabname+"_extra_sort_order") + button_sortorder = ToolButton(switch_values_symbol, elem_id=tabname+"_extra_sortorder", elem_classes=["sortorder"] + ([] if shared.opts.extra_networks_card_order == "Ascending" else ["sortReverse"]), visible=False, tooltip="Invert sort order") button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh", visible=False) checkbox_show_dirs = gr.Checkbox(True, label='Show dirs', elem_id=tabname+"_extra_show_dirs", elem_classes="show-dirs", visible=False) ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False) ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False) - for tab in unrelated_tabs: - tab.select(fn=lambda: [gr.update(visible=False) for _ in range(5)], inputs=[], outputs=[edit_search, dropdown_sort, button_sortorder, button_refresh, checkbox_show_dirs], show_progress=False) + tab_controls = [edit_search, dropdown_sort, button_sortorder, button_refresh, checkbox_show_dirs] - for tab in related_tabs: - tab.select(fn=lambda: [gr.update(visible=True) for _ in range(5)], inputs=[], outputs=[edit_search, dropdown_sort, button_sortorder, button_refresh, checkbox_show_dirs], show_progress=False) + for tab in unrelated_tabs: + tab.select(fn=lambda: [gr.update(visible=False) for _ in tab_controls], _js='function(){ extraNetworksUrelatedTabSelected("' + tabname + '"); }', inputs=[], outputs=tab_controls, show_progress=False) + + for page, tab in zip(ui.stored_extra_pages, related_tabs): + allow_prompt = "true" if page.allow_prompt else "false" + allow_negative_prompt = "true" if page.allow_negative_prompt else "false" + + jscode = 'extraNetworksTabSelected("' + tabname + '", "' + f"{tabname}_{page.id_page}_prompts" + '", ' + allow_prompt + ', ' + allow_negative_prompt + ');' + + tab.select(fn=lambda: [gr.update(visible=True) for _ in tab_controls], _js='function(){ ' + jscode + ' }', inputs=[], outputs=tab_controls, show_progress=False) + + dropdown_sort.change(fn=lambda: None, _js="function(){ applyExtraNetworkSort('" + tabname + "'); }") def pages_html(): if not ui.pages_contents: diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py index ca6c26076..1693e71f1 100644 --- a/modules/ui_extra_networks_checkpoints.py +++ b/modules/ui_extra_networks_checkpoints.py @@ -10,11 +10,16 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage): def __init__(self): super().__init__('Checkpoints') + self.allow_prompt = False + def refresh(self): shared.refresh_checkpoints() def create_item(self, name, index=None, enable_filter=True): checkpoint: sd_models.CheckpointInfo = sd_models.checkpoint_aliases.get(name) + if checkpoint is None: + return + path, ext = os.path.splitext(checkpoint.filename) return { "name": checkpoint.name_for_extra, @@ -30,9 +35,12 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage): } def list_items(self): + # instantiate a list to protect against concurrent modification names = list(sd_models.checkpoints_list) for index, name in enumerate(names): - yield self.create_item(name, index) + item = self.create_item(name, index) + if item is not None: + yield item def allowed_directories_for_previews(self): return [v for v in [shared.cmd_opts.ckpt_dir, sd_models.model_path] if v is not None] diff --git a/modules/ui_extra_networks_hypernets.py b/modules/ui_extra_networks_hypernets.py index 4cedf0851..c96c4fa3b 100644 --- a/modules/ui_extra_networks_hypernets.py +++ b/modules/ui_extra_networks_hypernets.py @@ -13,7 +13,10 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage): shared.reload_hypernetworks() def create_item(self, name, index=None, enable_filter=True): - full_path = shared.hypernetworks[name] + full_path = shared.hypernetworks.get(name) + if full_path is None: + return + path, ext = os.path.splitext(full_path) sha256 = sha256_from_cache(full_path, f'hypernet/{name}') shorthash = sha256[0:10] if sha256 else None @@ -31,8 +34,12 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage): } def list_items(self): - for index, name in enumerate(shared.hypernetworks): - yield self.create_item(name, index) + # instantiate a list to protect against concurrent modification + names = list(shared.hypernetworks) + for index, name in enumerate(names): + item = self.create_item(name, index) + if item is not None: + yield item def allowed_directories_for_previews(self): return [shared.cmd_opts.hypernetwork_dir] diff --git a/modules/ui_extra_networks_textual_inversion.py b/modules/ui_extra_networks_textual_inversion.py index 55ef0ea7b..1b334fda1 100644 --- a/modules/ui_extra_networks_textual_inversion.py +++ b/modules/ui_extra_networks_textual_inversion.py @@ -14,6 +14,8 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage): def create_item(self, name, index=None, enable_filter=True): embedding = sd_hijack.model_hijack.embedding_db.word_embeddings.get(name) + if embedding is None: + return path, ext = os.path.splitext(embedding.filename) return { @@ -29,8 +31,12 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage): } def list_items(self): - for index, name in enumerate(sd_hijack.model_hijack.embedding_db.word_embeddings): - yield self.create_item(name, index) + # instantiate a list to protect against concurrent modification + names = list(sd_hijack.model_hijack.embedding_db.word_embeddings) + for index, name in enumerate(names): + item = self.create_item(name, index) + if item is not None: + yield item def allowed_directories_for_previews(self): return list(sd_hijack.model_hijack.embedding_db.embedding_dirs) diff --git a/modules/ui_extra_networks_user_metadata.py b/modules/ui_extra_networks_user_metadata.py index bfec140cc..87aeb6f33 100644 --- a/modules/ui_extra_networks_user_metadata.py +++ b/modules/ui_extra_networks_user_metadata.py @@ -5,7 +5,7 @@ import os.path import gradio as gr -from modules import generation_parameters_copypaste, images, sysinfo, errors, ui_extra_networks +from modules import infotext, images, sysinfo, errors, ui_extra_networks class UserMetadataEditor: @@ -134,7 +134,7 @@ class UserMetadataEditor: basename, ext = os.path.splitext(filename) with open(basename + '.json', "w", encoding="utf8") as file: - json.dump(metadata, file, indent=4) + json.dump(metadata, file, indent=4, ensure_ascii=False) def save_user_metadata(self, name, desc, notes): user_metadata = self.get_user_metadata(name) @@ -181,7 +181,7 @@ class UserMetadataEditor: index = len(gallery) - 1 if index >= len(gallery) else index img_info = gallery[index if index >= 0 else 0] - image = generation_parameters_copypaste.image_from_url_text(img_info) + image = infotext.image_from_url_text(img_info) geninfo, items = images.read_info_from_image(image) images.save_image_with_geninfo(image, geninfo, item["local_preview"]) diff --git a/modules/ui_gradio_extensions.py b/modules/ui_gradio_extensions.py index 0d368f8b2..a86c368ef 100644 --- a/modules/ui_gradio_extensions.py +++ b/modules/ui_gradio_extensions.py @@ -1,17 +1,12 @@ import os import gradio as gr -from modules import localization, shared, scripts -from modules.paths import script_path, data_path, cwd +from modules import localization, shared, scripts, util +from modules.paths import script_path, data_path def webpath(fn): - if fn.startswith(cwd): - web_path = os.path.relpath(fn, cwd) - else: - web_path = os.path.abspath(fn) - - return f'file={web_path}?{os.path.getmtime(fn)}' + return f'file={util.truncate_path(fn)}?{os.path.getmtime(fn)}' def javascript_html(): diff --git a/modules/ui_loadsave.py b/modules/ui_loadsave.py index eb20ff258..693ff75c5 100644 --- a/modules/ui_loadsave.py +++ b/modules/ui_loadsave.py @@ -141,10 +141,10 @@ class UiLoadsave: def write_to_file(self, current_ui_settings): with open(self.filename, "w", encoding="utf8") as file: - json.dump(current_ui_settings, file, indent=4) + json.dump(current_ui_settings, file, indent=4, ensure_ascii=False) def dump_defaults(self): - """saves default values to a file unless tjhe file is present and there was an error loading default values at start""" + """saves default values to a file unless the file is present and there was an error loading default values at start""" if self.error_loading and os.path.exists(self.filename): return diff --git a/modules/ui_postprocessing.py b/modules/ui_postprocessing.py index 802e1ce71..b74a15323 100644 --- a/modules/ui_postprocessing.py +++ b/modules/ui_postprocessing.py @@ -1,9 +1,10 @@ import gradio as gr -from modules import scripts, shared, ui_common, postprocessing, call_queue -import modules.generation_parameters_copypaste as parameters_copypaste +from modules import scripts, shared, ui_common, postprocessing, call_queue, ui_toprow +import modules.infotext as parameters_copypaste def create_ui(): + dummy_component = gr.Label(visible=False) tab_index = gr.State(value=0) with gr.Row(equal_height=False, variant='compact'): @@ -20,11 +21,13 @@ def create_ui(): extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.", elem_id="extras_batch_output_dir") show_extras_results = gr.Checkbox(label='Show result images', value=True, elem_id="extras_show_extras_results") - submit = gr.Button('Generate', elem_id="extras_generate", variant='primary') - script_inputs = scripts.scripts_postproc.setup_ui() with gr.Column(): + toprow = ui_toprow.Toprow(is_compact=True, is_img2img=False, id_part="extras") + toprow.create_inline_toprow_image() + submit = toprow.submit + result_images, html_info_x, html_info, html_log = ui_common.create_output_panel("extras", shared.opts.outdir_extras_samples) tab_single.select(fn=lambda: 0, inputs=[], outputs=[tab_index]) @@ -32,8 +35,10 @@ def create_ui(): tab_batch_dir.select(fn=lambda: 2, inputs=[], outputs=[tab_index]) submit.click( - fn=call_queue.wrap_gradio_gpu_call(postprocessing.run_postprocessing, extra_outputs=[None, '']), + fn=call_queue.wrap_gradio_gpu_call(postprocessing.run_postprocessing_webui, extra_outputs=[None, '']), + _js="submit_extras", inputs=[ + dummy_component, tab_index, extras_image, image_batch, @@ -45,8 +50,9 @@ def create_ui(): outputs=[ result_images, html_info_x, - html_info, - ] + html_log, + ], + show_progress=False, ) parameters_copypaste.add_paste_fields("extras", extras_image, None) diff --git a/modules/ui_prompt_styles.py b/modules/ui_prompt_styles.py index 3bcf092fd..0d74c23fa 100644 --- a/modules/ui_prompt_styles.py +++ b/modules/ui_prompt_styles.py @@ -68,10 +68,10 @@ class UiPromptStyles: self.copy = ui_components.ToolButton(value=styles_copy_symbol, elem_id=f"{tabname}_style_copy", tooltip="Copy main UI prompt to style.") with gr.Row(): - self.prompt = gr.Textbox(label="Prompt", show_label=True, elem_id=f"{tabname}_edit_style_prompt", lines=3) + self.prompt = gr.Textbox(label="Prompt", show_label=True, elem_id=f"{tabname}_edit_style_prompt", lines=3, elem_classes=["prompt"]) with gr.Row(): - self.neg_prompt = gr.Textbox(label="Negative prompt", show_label=True, elem_id=f"{tabname}_edit_style_neg_prompt", lines=3) + self.neg_prompt = gr.Textbox(label="Negative prompt", show_label=True, elem_id=f"{tabname}_edit_style_neg_prompt", lines=3, elem_classes=["prompt"]) with gr.Row(): self.save = gr.Button('Save', variant='primary', elem_id=f'{tabname}_edit_style_save', visible=False) diff --git a/modules/ui_toprow.py b/modules/ui_toprow.py new file mode 100644 index 000000000..1abc9117b --- /dev/null +++ b/modules/ui_toprow.py @@ -0,0 +1,149 @@ +import gradio as gr + +from modules import shared, ui_prompt_styles +import modules.images + +from modules.ui_components import ToolButton + + +class Toprow: + """Creates a top row UI with prompts, generate button, styles, extra little buttons for things, and enables some functionality related to their operation""" + + prompt = None + prompt_img = None + negative_prompt = None + + button_interrogate = None + button_deepbooru = None + + interrupt = None + skip = None + submit = None + + paste = None + clear_prompt_button = None + apply_styles = None + restore_progress_button = None + + token_counter = None + token_button = None + negative_token_counter = None + negative_token_button = None + + ui_styles = None + + submit_box = None + + def __init__(self, is_img2img, is_compact=False, id_part=None): + if id_part is None: + id_part = "img2img" if is_img2img else "txt2img" + + self.id_part = id_part + self.is_img2img = is_img2img + self.is_compact = is_compact + + if not is_compact: + with gr.Row(elem_id=f"{id_part}_toprow", variant="compact"): + self.create_classic_toprow() + else: + self.create_submit_box() + + def create_classic_toprow(self): + self.create_prompts() + + with gr.Column(scale=1, elem_id=f"{self.id_part}_actions_column"): + self.create_submit_box() + + self.create_tools_row() + + self.create_styles_ui() + + def create_inline_toprow_prompts(self): + if not self.is_compact: + return + + self.create_prompts() + + with gr.Row(elem_classes=["toprow-compact-stylerow"]): + with gr.Column(elem_classes=["toprow-compact-tools"]): + self.create_tools_row() + with gr.Column(): + self.create_styles_ui() + + def create_inline_toprow_image(self): + if not self.is_compact: + return + + self.submit_box.render() + + def create_prompts(self): + with gr.Column(elem_id=f"{self.id_part}_prompt_container", elem_classes=["prompt-container-compact"] if self.is_compact else [], scale=6): + with gr.Row(elem_id=f"{self.id_part}_prompt_row", elem_classes=["prompt-row"]): + self.prompt = gr.Textbox(label="Prompt", elem_id=f"{self.id_part}_prompt", show_label=False, lines=3, placeholder="Prompt\n(Press Ctrl+Enter to generate, Alt+Enter to skip, Esc to interrupt)", elem_classes=["prompt"]) + self.prompt_img = gr.File(label="", elem_id=f"{self.id_part}_prompt_image", file_count="single", type="binary", visible=False) + + with gr.Row(elem_id=f"{self.id_part}_neg_prompt_row", elem_classes=["prompt-row"]): + self.negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{self.id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt\n(Press Ctrl+Enter to generate, Alt+Enter to skip, Esc to interrupt)", elem_classes=["prompt"]) + + self.prompt_img.change( + fn=modules.images.image_data, + inputs=[self.prompt_img], + outputs=[self.prompt, self.prompt_img], + show_progress=False, + ) + + def create_submit_box(self): + with gr.Row(elem_id=f"{self.id_part}_generate_box", elem_classes=["generate-box"] + (["generate-box-compact"] if self.is_compact else []), render=not self.is_compact) as submit_box: + self.submit_box = submit_box + + self.interrupt = gr.Button('Interrupt', elem_id=f"{self.id_part}_interrupt", elem_classes="generate-box-interrupt") + self.skip = gr.Button('Skip', elem_id=f"{self.id_part}_skip", elem_classes="generate-box-skip") + self.submit = gr.Button('Generate', elem_id=f"{self.id_part}_generate", variant='primary') + + self.skip.click( + fn=lambda: shared.state.skip(), + inputs=[], + outputs=[], + ) + + def interrupt_function(): + if shared.state.job_count > 1 and shared.opts.interrupt_after_current: + shared.state.stop_generating() + else: + shared.state.interrupt() + + self.interrupt.click( + fn=interrupt_function, + inputs=[], + outputs=[], + ) + + def create_tools_row(self): + with gr.Row(elem_id=f"{self.id_part}_tools"): + from modules.ui import paste_symbol, clear_prompt_symbol, restore_progress_symbol + + self.paste = ToolButton(value=paste_symbol, elem_id="paste", tooltip="Read generation parameters from prompt or last generation if prompt is empty into user interface.") + self.clear_prompt_button = ToolButton(value=clear_prompt_symbol, elem_id=f"{self.id_part}_clear_prompt", tooltip="Clear prompt") + self.apply_styles = ToolButton(value=ui_prompt_styles.styles_materialize_symbol, elem_id=f"{self.id_part}_style_apply", tooltip="Apply all selected styles to prompts.") + + if self.is_img2img: + self.button_interrogate = ToolButton('📎', tooltip='Interrogate CLIP - use CLIP neural network to create a text describing the image, and put it into the prompt field', elem_id="interrogate") + self.button_deepbooru = ToolButton('📦', tooltip='Interrogate DeepBooru - use DeepBooru neural network to create a text describing the image, and put it into the prompt field', elem_id="deepbooru") + + self.restore_progress_button = ToolButton(value=restore_progress_symbol, elem_id=f"{self.id_part}_restore_progress", visible=False, tooltip="Restore progress") + + self.token_counter = gr.HTML(value="0/75", elem_id=f"{self.id_part}_token_counter", elem_classes=["token-counter"]) + self.token_button = gr.Button(visible=False, elem_id=f"{self.id_part}_token_button") + self.negative_token_counter = gr.HTML(value="0/75", elem_id=f"{self.id_part}_negative_token_counter", elem_classes=["token-counter"]) + self.negative_token_button = gr.Button(visible=False, elem_id=f"{self.id_part}_negative_token_button") + + self.clear_prompt_button.click( + fn=lambda *x: x, + _js="confirm_clear_prompt", + inputs=[self.prompt, self.negative_prompt], + outputs=[self.prompt, self.negative_prompt], + ) + + def create_styles_ui(self): + self.ui_styles = ui_prompt_styles.UiPromptStyles(self.id_part, self.prompt, self.negative_prompt) + self.ui_styles.setup_apply_button(self.apply_styles) diff --git a/modules/upscaler.py b/modules/upscaler.py index e682bbaa2..3aee69db8 100644 --- a/modules/upscaler.py +++ b/modules/upscaler.py @@ -57,6 +57,9 @@ class Upscaler: dest_h = int((img.height * scale) // 8 * 8) for _ in range(3): + if img.width >= dest_w and img.height >= dest_h: + break + shape = (img.width, img.height) img = self.do_upscale(img, selected_model) @@ -64,9 +67,6 @@ class Upscaler: if shape == (img.width, img.height): break - if img.width >= dest_w and img.height >= dest_h: - break - if img.width != dest_w or img.height != dest_h: img = img.resize((int(dest_w), int(dest_h)), resample=LANCZOS) @@ -98,6 +98,9 @@ class UpscalerData: self.scale = scale self.model = model + def __repr__(self): + return f"" + class UpscalerNone(Upscaler): name = "None" diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py new file mode 100644 index 000000000..f5cb92d51 --- /dev/null +++ b/modules/upscaler_utils.py @@ -0,0 +1,140 @@ +import logging +from typing import Callable + +import numpy as np +import torch +import tqdm +from PIL import Image + +from modules import images, shared, torch_utils + +logger = logging.getLogger(__name__) + + +def upscale_without_tiling(model, img: Image.Image): + img = np.array(img) + img = img[:, :, ::-1] + img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255 + img = torch.from_numpy(img).float() + + param = torch_utils.get_param(model) + img = img.unsqueeze(0).to(device=param.device, dtype=param.dtype) + + with torch.no_grad(): + output = model(img) + + output = output.squeeze().float().cpu().clamp_(0, 1).numpy() + output = 255. * np.moveaxis(output, 0, 2) + output = output.astype(np.uint8) + output = output[:, :, ::-1] + return Image.fromarray(output, 'RGB') + + +def upscale_with_model( + model: Callable[[torch.Tensor], torch.Tensor], + img: Image.Image, + *, + tile_size: int, + tile_overlap: int = 0, + desc="tiled upscale", +) -> Image.Image: + if tile_size <= 0: + logger.debug("Upscaling %s without tiling", img) + output = upscale_without_tiling(model, img) + logger.debug("=> %s", output) + return output + + grid = images.split_grid(img, tile_size, tile_size, tile_overlap) + newtiles = [] + + with tqdm.tqdm(total=grid.tile_count, desc=desc) as p: + for y, h, row in grid.tiles: + newrow = [] + for x, w, tile in row: + logger.debug("Tile (%d, %d) %s...", x, y, tile) + output = upscale_without_tiling(model, tile) + scale_factor = output.width // tile.width + logger.debug("=> %s (scale factor %s)", output, scale_factor) + newrow.append([x * scale_factor, w * scale_factor, output]) + p.update(1) + newtiles.append([y * scale_factor, h * scale_factor, newrow]) + + newgrid = images.Grid( + newtiles, + tile_w=grid.tile_w * scale_factor, + tile_h=grid.tile_h * scale_factor, + image_w=grid.image_w * scale_factor, + image_h=grid.image_h * scale_factor, + overlap=grid.overlap * scale_factor, + ) + return images.combine_grid(newgrid) + + +def tiled_upscale_2( + img, + model, + *, + tile_size: int, + tile_overlap: int, + scale: int, + device, + desc="Tiled upscale", +): + # Alternative implementation of `upscale_with_model` originally used by + # SwinIR and ScuNET. It differs from `upscale_with_model` in that tiling and + # weighting is done in PyTorch space, as opposed to `images.Grid` doing it in + # Pillow space without weighting. + b, c, h, w = img.size() + tile_size = min(tile_size, h, w) + + if tile_size <= 0: + logger.debug("Upscaling %s without tiling", img.shape) + return model(img) + + stride = tile_size - tile_overlap + h_idx_list = list(range(0, h - tile_size, stride)) + [h - tile_size] + w_idx_list = list(range(0, w - tile_size, stride)) + [w - tile_size] + result = torch.zeros( + b, + c, + h * scale, + w * scale, + device=device, + ).type_as(img) + weights = torch.zeros_like(result) + logger.debug("Upscaling %s to %s with tiles", img.shape, result.shape) + with tqdm.tqdm(total=len(h_idx_list) * len(w_idx_list), desc=desc) as pbar: + for h_idx in h_idx_list: + if shared.state.interrupted or shared.state.skipped: + break + + for w_idx in w_idx_list: + if shared.state.interrupted or shared.state.skipped: + break + + in_patch = img[ + ..., + h_idx : h_idx + tile_size, + w_idx : w_idx + tile_size, + ] + out_patch = model(in_patch) + + result[ + ..., + h_idx * scale : (h_idx + tile_size) * scale, + w_idx * scale : (w_idx + tile_size) * scale, + ].add_(out_patch) + + out_patch_mask = torch.ones_like(out_patch) + + weights[ + ..., + h_idx * scale : (h_idx + tile_size) * scale, + w_idx * scale : (w_idx + tile_size) * scale, + ].add_(out_patch_mask) + + pbar.update(1) + + output = result.div_(weights) + + return output diff --git a/modules/util.py b/modules/util.py index 60afc0670..4861bcb08 100644 --- a/modules/util.py +++ b/modules/util.py @@ -2,7 +2,7 @@ import os import re from modules import shared -from modules.paths_internal import script_path +from modules.paths_internal import script_path, cwd def natural_sort_key(s, regex=re.compile('([0-9]+)')): @@ -56,3 +56,13 @@ def ldm_print(*args, **kwargs): return print(*args, **kwargs) + + +def truncate_path(target_path, base_path=cwd): + abs_target, abs_base = os.path.abspath(target_path), os.path.abspath(base_path) + try: + if os.path.commonpath([abs_target, abs_base]) == abs_base: + return os.path.relpath(abs_target, abs_base) + except ValueError: + pass + return abs_target diff --git a/modules/xlmr.py b/modules/xlmr.py index a407a3cad..319771b7b 100644 --- a/modules/xlmr.py +++ b/modules/xlmr.py @@ -5,6 +5,9 @@ from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRoberta from transformers import XLMRobertaModel,XLMRobertaTokenizer from typing import Optional +from modules import torch_utils + + class BertSeriesConfig(BertConfig): def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs): @@ -62,7 +65,7 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel): self.post_init() def encode(self,c): - device = next(self.parameters()).device + device = torch_utils.get_param(self).device text = self.tokenizer(c, truncation=True, max_length=77, diff --git a/modules/xlmr_m18.py b/modules/xlmr_m18.py index a727e8655..f60555049 100644 --- a/modules/xlmr_m18.py +++ b/modules/xlmr_m18.py @@ -4,6 +4,8 @@ import torch from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig from transformers import XLMRobertaModel,XLMRobertaTokenizer from typing import Optional +from modules import torch_utils + class BertSeriesConfig(BertConfig): def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs): @@ -68,7 +70,7 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel): self.post_init() def encode(self,c): - device = next(self.parameters()).device + device = torch_utils.get_param(self).device text = self.tokenizer(c, truncation=True, max_length=77, diff --git a/modules/xpu_specific.py b/modules/xpu_specific.py new file mode 100644 index 000000000..f7687a66c --- /dev/null +++ b/modules/xpu_specific.py @@ -0,0 +1,124 @@ +from modules import shared +from modules.sd_hijack_utils import CondFunc + +has_ipex = False +try: + import torch + import intel_extension_for_pytorch as ipex # noqa: F401 + has_ipex = True +except Exception: + pass + + +def check_for_xpu(): + return has_ipex and hasattr(torch, 'xpu') and torch.xpu.is_available() + + +def get_xpu_device_string(): + if shared.cmd_opts.device_id is not None: + return f"xpu:{shared.cmd_opts.device_id}" + return "xpu" + + +def torch_xpu_gc(): + with torch.xpu.device(get_xpu_device_string()): + torch.xpu.empty_cache() + + +has_xpu = check_for_xpu() + + +# Arc GPU cannot allocate a single block larger than 4GB: https://github.com/intel/compute-runtime/issues/627 +# Here we implement a slicing algorithm to split large batch size into smaller chunks, +# so that SDPA of each chunk wouldn't require any allocation larger than ARC_SINGLE_ALLOCATION_LIMIT. +# The heuristic limit (TOTAL_VRAM // 8) is tuned for Intel Arc A770 16G and Arc A750 8G, +# which is the best trade-off between VRAM usage and performance. +ARC_SINGLE_ALLOCATION_LIMIT = {} +orig_sdp_attn_func = torch.nn.functional.scaled_dot_product_attention +def torch_xpu_scaled_dot_product_attention( + query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, *args, **kwargs +): + # cast to same dtype first + key = key.to(query.dtype) + value = value.to(query.dtype) + + N = query.shape[:-2] # Batch size + L = query.size(-2) # Target sequence length + E = query.size(-1) # Embedding dimension of the query and key + S = key.size(-2) # Source sequence length + Ev = value.size(-1) # Embedding dimension of the value + + total_batch_size = torch.numel(torch.empty(N)) + device_id = query.device.index + if device_id not in ARC_SINGLE_ALLOCATION_LIMIT: + ARC_SINGLE_ALLOCATION_LIMIT[device_id] = min(torch.xpu.get_device_properties(device_id).total_memory // 8, 4 * 1024 * 1024 * 1024) + batch_size_limit = max(1, ARC_SINGLE_ALLOCATION_LIMIT[device_id] // (L * S * query.element_size())) + + if total_batch_size <= batch_size_limit: + return orig_sdp_attn_func( + query, + key, + value, + attn_mask, + dropout_p, + is_causal, + *args, **kwargs + ) + + query = torch.reshape(query, (-1, L, E)) + key = torch.reshape(key, (-1, S, E)) + value = torch.reshape(value, (-1, S, Ev)) + if attn_mask is not None: + attn_mask = attn_mask.view(-1, L, S) + chunk_count = (total_batch_size + batch_size_limit - 1) // batch_size_limit + outputs = [] + for i in range(chunk_count): + attn_mask_chunk = ( + None + if attn_mask is None + else attn_mask[i * batch_size_limit : (i + 1) * batch_size_limit, :, :] + ) + chunk_output = orig_sdp_attn_func( + query[i * batch_size_limit : (i + 1) * batch_size_limit, :, :], + key[i * batch_size_limit : (i + 1) * batch_size_limit, :, :], + value[i * batch_size_limit : (i + 1) * batch_size_limit, :, :], + attn_mask_chunk, + dropout_p, + is_causal, + *args, **kwargs + ) + outputs.append(chunk_output) + result = torch.cat(outputs, dim=0) + return torch.reshape(result, (*N, L, Ev)) + + +if has_xpu: + # W/A for https://github.com/intel/intel-extension-for-pytorch/issues/452: torch.Generator API doesn't support XPU device + CondFunc('torch.Generator', + lambda orig_func, device=None: torch.xpu.Generator(device), + lambda orig_func, device=None: device is not None and device.type == "xpu") + + # W/A for some OPs that could not handle different input dtypes + CondFunc('torch.nn.functional.layer_norm', + lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs: + orig_func(input.to(weight.data.dtype), normalized_shape, weight, *args, **kwargs), + lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs: + weight is not None and input.dtype != weight.data.dtype) + CondFunc('torch.nn.modules.GroupNorm.forward', + lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), + lambda orig_func, self, input: input.dtype != self.weight.data.dtype) + CondFunc('torch.nn.modules.linear.Linear.forward', + lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), + lambda orig_func, self, input: input.dtype != self.weight.data.dtype) + CondFunc('torch.nn.modules.conv.Conv2d.forward', + lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), + lambda orig_func, self, input: input.dtype != self.weight.data.dtype) + CondFunc('torch.bmm', + lambda orig_func, input, mat2, out=None: orig_func(input.to(mat2.dtype), mat2, out=out), + lambda orig_func, input, mat2, out=None: input.dtype != mat2.dtype) + CondFunc('torch.cat', + lambda orig_func, tensors, dim=0, out=None: orig_func([t.to(tensors[0].dtype) for t in tensors], dim=dim, out=out), + lambda orig_func, tensors, dim=0, out=None: not all(t.dtype == tensors[0].dtype for t in tensors)) + CondFunc('torch.nn.functional.scaled_dot_product_attention', + lambda orig_func, *args, **kwargs: torch_xpu_scaled_dot_product_attention(*args, **kwargs), + lambda orig_func, query, *args, **kwargs: query.is_xpu) diff --git a/pyproject.toml b/pyproject.toml index 80541a8f3..d03036e7d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,6 +16,7 @@ exclude = [ ignore = [ "E501", # Line too long + "E721", # Do not compare types, use `isinstance` "E731", # Do not assign a `lambda` expression, use a `def` "I001", # Import block is un-sorted or un-formatted diff --git a/requirements.txt b/requirements.txt index 80b438455..731a1be7d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,12 +2,11 @@ GitPython Pillow accelerate -basicsr blendmodes clean-fid einops +facexlib fastapi>=0.90.1 -gfpgan gradio==3.41.2 inflection jsonmerge @@ -20,13 +19,11 @@ open-clip-torch piexif psutil pytorch_lightning -realesrgan requests resize-right safetensors scikit-image>=0.19 -timm tomesd torch torchdiffeq diff --git a/requirements_versions.txt b/requirements_versions.txt index 7d27f2be3..2a922f288 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -1,31 +1,30 @@ GitPython==3.1.32 Pillow==9.5.0 accelerate==0.21.0 -basicsr==1.4.2 blendmodes==2022 clean-fid==0.1.35 einops==0.4.1 +facexlib==0.3.0 fastapi==0.94.0 -gfpgan==1.3.8 gradio==3.41.2 httpcore==0.15 inflection==0.5.1 jsonmerge==1.8.0 kornia==0.6.7 lark==1.1.2 -numpy==1.23.5 +numpy==1.26.2 omegaconf==2.2.3 open-clip-torch==2.20.0 piexif==1.1.3 psutil==5.9.5 pytorch_lightning==1.9.4 -realesrgan==0.3.0 resize-right==0.0.2 safetensors==0.3.1 scikit-image==0.21.0 -timm==0.9.2 +spandrel==0.1.6 tomesd==0.1.3 torch torchdiffeq==0.2.3 torchsde==0.2.6 transformers==4.30.2 +httpx==0.24.1 diff --git a/script.js b/script.js index 5f6ee3542..be1bc317e 100644 --- a/script.js +++ b/script.js @@ -121,26 +121,56 @@ document.addEventListener("DOMContentLoaded", function() { }); /** - * Add a ctrl+enter as a shortcut to start a generation + * Add keyboard shortcuts: + * Ctrl+Enter to start/restart a generation + * Alt/Option+Enter to skip a generation + * Esc to interrupt a generation */ document.addEventListener('keydown', function(e) { const isEnter = e.key === 'Enter' || e.keyCode === 13; - const isModifierKey = e.metaKey || e.ctrlKey || e.altKey; + const isCtrlKey = e.metaKey || e.ctrlKey; + const isAltKey = e.altKey; + const isEsc = e.key === 'Escape'; - const interruptButton = get_uiCurrentTabContent().querySelector('button[id$=_interrupt]'); const generateButton = get_uiCurrentTabContent().querySelector('button[id$=_generate]'); + const interruptButton = get_uiCurrentTabContent().querySelector('button[id$=_interrupt]'); + const skipButton = get_uiCurrentTabContent().querySelector('button[id$=_skip]'); - if (isEnter && isModifierKey) { + if (isCtrlKey && isEnter) { if (interruptButton.style.display === 'block') { interruptButton.click(); - setTimeout(function() { - generateButton.click(); - }, 500); + const callback = (mutationList) => { + for (const mutation of mutationList) { + if (mutation.type === 'attributes' && mutation.attributeName === 'style') { + if (interruptButton.style.display === 'none') { + generateButton.click(); + observer.disconnect(); + } + } + } + }; + const observer = new MutationObserver(callback); + observer.observe(interruptButton, {attributes: true}); } else { generateButton.click(); } e.preventDefault(); } + + if (isAltKey && isEnter) { + skipButton.click(); + e.preventDefault(); + } + + if (isEsc) { + const globalPopup = document.querySelector('.global-popup'); + const lightboxModal = document.querySelector('#lightboxModal'); + if (!globalPopup || globalPopup.style.display === 'none') { + if (document.activeElement === lightboxModal) return; + interruptButton.click(); + e.preventDefault(); + } + } }); /** diff --git a/scripts/loopback.py b/scripts/loopback.py index 2d5feaf9b..800ee882a 100644 --- a/scripts/loopback.py +++ b/scripts/loopback.py @@ -95,7 +95,7 @@ class Script(scripts.Script): processed = processing.process_images(p) # Generation cancelled. - if state.interrupted: + if state.interrupted or state.stopping_generation: break if initial_seed is None: @@ -122,8 +122,8 @@ class Script(scripts.Script): p.inpainting_fill = original_inpainting_fill - if state.interrupted: - break + if state.interrupted or state.stopping_generation: + break if len(history) > 1: grid = images.image_grid(history, rows=1) diff --git a/scripts/postprocessing_caption.py b/scripts/postprocessing_caption.py new file mode 100644 index 000000000..5592a8987 --- /dev/null +++ b/scripts/postprocessing_caption.py @@ -0,0 +1,30 @@ +from modules import scripts_postprocessing, ui_components, deepbooru, shared +import gradio as gr + + +class ScriptPostprocessingCeption(scripts_postprocessing.ScriptPostprocessing): + name = "Caption" + order = 4040 + + def ui(self): + with ui_components.InputAccordion(False, label="Caption") as enable: + option = gr.CheckboxGroup(value=["Deepbooru"], choices=["Deepbooru", "BLIP"], show_label=False) + + return { + "enable": enable, + "option": option, + } + + def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, option): + if not enable: + return + + captions = [pp.caption] + + if "Deepbooru" in option: + captions.append(deepbooru.model.tag(pp.image)) + + if "BLIP" in option: + captions.append(shared.interrogator.interrogate(pp.image.convert("RGB"))) + + pp.caption = ", ".join([x for x in captions if x]) diff --git a/scripts/postprocessing_codeformer.py b/scripts/postprocessing_codeformer.py index a7d80d40e..e1e156ddc 100644 --- a/scripts/postprocessing_codeformer.py +++ b/scripts/postprocessing_codeformer.py @@ -1,28 +1,28 @@ from PIL import Image import numpy as np -from modules import scripts_postprocessing, codeformer_model +from modules import scripts_postprocessing, codeformer_model, ui_components import gradio as gr -from modules.ui_components import FormRow - class ScriptPostprocessingCodeFormer(scripts_postprocessing.ScriptPostprocessing): name = "CodeFormer" order = 3000 def ui(self): - with FormRow(): - codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, elem_id="extras_codeformer_visibility") - codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, elem_id="extras_codeformer_weight") + with ui_components.InputAccordion(False, label="CodeFormer") as enable: + with gr.Row(): + codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Visibility", value=1.0, elem_id="extras_codeformer_visibility") + codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Weight (0 = maximum effect, 1 = minimum effect)", value=0, elem_id="extras_codeformer_weight") return { + "enable": enable, "codeformer_visibility": codeformer_visibility, "codeformer_weight": codeformer_weight, } - def process(self, pp: scripts_postprocessing.PostprocessedImage, codeformer_visibility, codeformer_weight): - if codeformer_visibility == 0: + def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, codeformer_visibility, codeformer_weight): + if codeformer_visibility == 0 or not enable: return restored_img = codeformer_model.codeformer.restore(np.array(pp.image, dtype=np.uint8), w=codeformer_weight) diff --git a/scripts/postprocessing_create_flipped_copies.py b/scripts/postprocessing_create_flipped_copies.py new file mode 100644 index 000000000..b673003b6 --- /dev/null +++ b/scripts/postprocessing_create_flipped_copies.py @@ -0,0 +1,32 @@ +from PIL import ImageOps, Image + +from modules import scripts_postprocessing, ui_components +import gradio as gr + + +class ScriptPostprocessingCreateFlippedCopies(scripts_postprocessing.ScriptPostprocessing): + name = "Create flipped copies" + order = 4030 + + def ui(self): + with ui_components.InputAccordion(False, label="Create flipped copies") as enable: + with gr.Row(): + option = gr.CheckboxGroup(value=["Horizontal"], choices=["Horizontal", "Vertical", "Both"], show_label=False) + + return { + "enable": enable, + "option": option, + } + + def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, option): + if not enable: + return + + if "Horizontal" in option: + pp.extra_images.append(ImageOps.mirror(pp.image)) + + if "Vertical" in option: + pp.extra_images.append(pp.image.transpose(Image.Transpose.FLIP_TOP_BOTTOM)) + + if "Both" in option: + pp.extra_images.append(pp.image.transpose(Image.Transpose.FLIP_TOP_BOTTOM).transpose(Image.Transpose.FLIP_LEFT_RIGHT)) diff --git a/scripts/postprocessing_focal_crop.py b/scripts/postprocessing_focal_crop.py new file mode 100644 index 000000000..cff1dbc54 --- /dev/null +++ b/scripts/postprocessing_focal_crop.py @@ -0,0 +1,54 @@ + +from modules import scripts_postprocessing, ui_components, errors +import gradio as gr + +from modules.textual_inversion import autocrop + + +class ScriptPostprocessingFocalCrop(scripts_postprocessing.ScriptPostprocessing): + name = "Auto focal point crop" + order = 4010 + + def ui(self): + with ui_components.InputAccordion(False, label="Auto focal point crop") as enable: + face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05, elem_id="postprocess_focal_crop_face_weight") + entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05, elem_id="postprocess_focal_crop_entropy_weight") + edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="postprocess_focal_crop_edges_weight") + debug = gr.Checkbox(label='Create debug image', elem_id="train_process_focal_crop_debug") + + return { + "enable": enable, + "face_weight": face_weight, + "entropy_weight": entropy_weight, + "edges_weight": edges_weight, + "debug": debug, + } + + def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, face_weight, entropy_weight, edges_weight, debug): + if not enable: + return + + if not pp.shared.target_width or not pp.shared.target_height: + return + + dnn_model_path = None + try: + dnn_model_path = autocrop.download_and_cache_models() + except Exception: + errors.report("Unable to load face detection model for auto crop selection. Falling back to lower quality haar method.", exc_info=True) + + autocrop_settings = autocrop.Settings( + crop_width=pp.shared.target_width, + crop_height=pp.shared.target_height, + face_points_weight=face_weight, + entropy_points_weight=entropy_weight, + corner_points_weight=edges_weight, + annotate_image=debug, + dnn_model_path=dnn_model_path, + ) + + result, *others = autocrop.crop_image(pp.image, autocrop_settings) + + pp.image = result + pp.extra_images = [pp.create_copy(x, nametags=["focal-crop-debug"], disable_processing=True) for x in others] + diff --git a/scripts/postprocessing_gfpgan.py b/scripts/postprocessing_gfpgan.py index d854f3f77..6e7566055 100644 --- a/scripts/postprocessing_gfpgan.py +++ b/scripts/postprocessing_gfpgan.py @@ -1,26 +1,25 @@ from PIL import Image import numpy as np -from modules import scripts_postprocessing, gfpgan_model +from modules import scripts_postprocessing, gfpgan_model, ui_components import gradio as gr -from modules.ui_components import FormRow - class ScriptPostprocessingGfpGan(scripts_postprocessing.ScriptPostprocessing): name = "GFPGAN" order = 2000 def ui(self): - with FormRow(): - gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN visibility", value=0, elem_id="extras_gfpgan_visibility") + with ui_components.InputAccordion(False, label="GFPGAN") as enable: + gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Visibility", value=1.0, elem_id="extras_gfpgan_visibility") return { + "enable": enable, "gfpgan_visibility": gfpgan_visibility, } - def process(self, pp: scripts_postprocessing.PostprocessedImage, gfpgan_visibility): - if gfpgan_visibility == 0: + def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, gfpgan_visibility): + if gfpgan_visibility == 0 or not enable: return restored_img = gfpgan_model.gfpgan_fix_faces(np.array(pp.image, dtype=np.uint8)) diff --git a/scripts/postprocessing_split_oversized.py b/scripts/postprocessing_split_oversized.py new file mode 100644 index 000000000..c4a03160f --- /dev/null +++ b/scripts/postprocessing_split_oversized.py @@ -0,0 +1,71 @@ +import math + +from modules import scripts_postprocessing, ui_components +import gradio as gr + + +def split_pic(image, inverse_xy, width, height, overlap_ratio): + if inverse_xy: + from_w, from_h = image.height, image.width + to_w, to_h = height, width + else: + from_w, from_h = image.width, image.height + to_w, to_h = width, height + h = from_h * to_w // from_w + if inverse_xy: + image = image.resize((h, to_w)) + else: + image = image.resize((to_w, h)) + + split_count = math.ceil((h - to_h * overlap_ratio) / (to_h * (1.0 - overlap_ratio))) + y_step = (h - to_h) / (split_count - 1) + for i in range(split_count): + y = int(y_step * i) + if inverse_xy: + splitted = image.crop((y, 0, y + to_h, to_w)) + else: + splitted = image.crop((0, y, to_w, y + to_h)) + yield splitted + + +class ScriptPostprocessingSplitOversized(scripts_postprocessing.ScriptPostprocessing): + name = "Split oversized images" + order = 4000 + + def ui(self): + with ui_components.InputAccordion(False, label="Split oversized images") as enable: + with gr.Row(): + split_threshold = gr.Slider(label='Threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="postprocess_split_threshold") + overlap_ratio = gr.Slider(label='Overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05, elem_id="postprocess_overlap_ratio") + + return { + "enable": enable, + "split_threshold": split_threshold, + "overlap_ratio": overlap_ratio, + } + + def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, split_threshold, overlap_ratio): + if not enable: + return + + width = pp.shared.target_width + height = pp.shared.target_height + + if not width or not height: + return + + if pp.image.height > pp.image.width: + ratio = (pp.image.width * height) / (pp.image.height * width) + inverse_xy = False + else: + ratio = (pp.image.height * width) / (pp.image.width * height) + inverse_xy = True + + if ratio >= 1.0 and ratio > split_threshold: + return + + result, *others = split_pic(pp.image, inverse_xy, width, height, overlap_ratio) + + pp.image = result + pp.extra_images = [pp.create_copy(x) for x in others] + diff --git a/scripts/postprocessing_upscale.py b/scripts/postprocessing_upscale.py index eb42a29e5..ed709688d 100644 --- a/scripts/postprocessing_upscale.py +++ b/scripts/postprocessing_upscale.py @@ -81,6 +81,14 @@ class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing): return image + def process_firstpass(self, pp: scripts_postprocessing.PostprocessedImage, upscale_mode=1, upscale_by=2.0, upscale_to_width=None, upscale_to_height=None, upscale_crop=False, upscaler_1_name=None, upscaler_2_name=None, upscaler_2_visibility=0.0): + if upscale_mode == 1: + pp.shared.target_width = upscale_to_width + pp.shared.target_height = upscale_to_height + else: + pp.shared.target_width = int(pp.image.width * upscale_by) + pp.shared.target_height = int(pp.image.height * upscale_by) + def process(self, pp: scripts_postprocessing.PostprocessedImage, upscale_mode=1, upscale_by=2.0, upscale_to_width=None, upscale_to_height=None, upscale_crop=False, upscaler_1_name=None, upscaler_2_name=None, upscaler_2_visibility=0.0): if upscaler_1_name == "None": upscaler_1_name = None @@ -126,6 +134,10 @@ class ScriptPostprocessingUpscaleSimple(ScriptPostprocessingUpscale): "upscaler_name": upscaler_name, } + def process_firstpass(self, pp: scripts_postprocessing.PostprocessedImage, upscale_by=2.0, upscaler_name=None): + pp.shared.target_width = int(pp.image.width * upscale_by) + pp.shared.target_height = int(pp.image.height * upscale_by) + def process(self, pp: scripts_postprocessing.PostprocessedImage, upscale_by=2.0, upscaler_name=None): if upscaler_name is None or upscaler_name == "None": return diff --git a/scripts/processing_autosized_crop.py b/scripts/processing_autosized_crop.py new file mode 100644 index 000000000..7e6749898 --- /dev/null +++ b/scripts/processing_autosized_crop.py @@ -0,0 +1,64 @@ +from PIL import Image + +from modules import scripts_postprocessing, ui_components +import gradio as gr + + +def center_crop(image: Image, w: int, h: int): + iw, ih = image.size + if ih / h < iw / w: + sw = w * ih / h + box = (iw - sw) / 2, 0, iw - (iw - sw) / 2, ih + else: + sh = h * iw / w + box = 0, (ih - sh) / 2, iw, ih - (ih - sh) / 2 + return image.resize((w, h), Image.Resampling.LANCZOS, box) + + +def multicrop_pic(image: Image, mindim, maxdim, minarea, maxarea, objective, threshold): + iw, ih = image.size + err = lambda w, h: 1 - (lambda x: x if x < 1 else 1 / x)(iw / ih / (w / h)) + wh = max(((w, h) for w in range(mindim, maxdim + 1, 64) for h in range(mindim, maxdim + 1, 64) + if minarea <= w * h <= maxarea and err(w, h) <= threshold), + key=lambda wh: (wh[0] * wh[1], -err(*wh))[::1 if objective == 'Maximize area' else -1], + default=None + ) + return wh and center_crop(image, *wh) + + +class ScriptPostprocessingAutosizedCrop(scripts_postprocessing.ScriptPostprocessing): + name = "Auto-sized crop" + order = 4020 + + def ui(self): + with ui_components.InputAccordion(False, label="Auto-sized crop") as enable: + gr.Markdown('Each image is center-cropped with an automatically chosen width and height.') + with gr.Row(): + mindim = gr.Slider(minimum=64, maximum=2048, step=8, label="Dimension lower bound", value=384, elem_id="postprocess_multicrop_mindim") + maxdim = gr.Slider(minimum=64, maximum=2048, step=8, label="Dimension upper bound", value=768, elem_id="postprocess_multicrop_maxdim") + with gr.Row(): + minarea = gr.Slider(minimum=64 * 64, maximum=2048 * 2048, step=1, label="Area lower bound", value=64 * 64, elem_id="postprocess_multicrop_minarea") + maxarea = gr.Slider(minimum=64 * 64, maximum=2048 * 2048, step=1, label="Area upper bound", value=640 * 640, elem_id="postprocess_multicrop_maxarea") + with gr.Row(): + objective = gr.Radio(["Maximize area", "Minimize error"], value="Maximize area", label="Resizing objective", elem_id="postprocess_multicrop_objective") + threshold = gr.Slider(minimum=0, maximum=1, step=0.01, label="Error threshold", value=0.1, elem_id="postprocess_multicrop_threshold") + + return { + "enable": enable, + "mindim": mindim, + "maxdim": maxdim, + "minarea": minarea, + "maxarea": maxarea, + "objective": objective, + "threshold": threshold, + } + + def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, mindim, maxdim, minarea, maxarea, objective, threshold): + if not enable: + return + + cropped = multicrop_pic(pp.image, mindim, maxdim, minarea, maxarea, objective, threshold) + if cropped is not None: + pp.image = cropped + else: + print(f"skipped {pp.image.width}x{pp.image.height} image (can't find suitable size within error threshold)") diff --git a/scripts/prompts_from_file.py b/scripts/prompts_from_file.py index ca73b2a56..a4a2f24dd 100644 --- a/scripts/prompts_from_file.py +++ b/scripts/prompts_from_file.py @@ -114,6 +114,7 @@ class Script(scripts.Script): def ui(self, is_img2img): checkbox_iterate = gr.Checkbox(label="Iterate seed every line", value=False, elem_id=self.elem_id("checkbox_iterate")) checkbox_iterate_batch = gr.Checkbox(label="Use same random seed for all lines", value=False, elem_id=self.elem_id("checkbox_iterate_batch")) + prompt_position = gr.Radio(["start", "end"], label="Insert prompts at the", elem_id=self.elem_id("prompt_position"), value="start") prompt_txt = gr.Textbox(label="List of prompt inputs", lines=1, elem_id=self.elem_id("prompt_txt")) file = gr.File(label="Upload prompt inputs", type='binary', elem_id=self.elem_id("file")) @@ -124,9 +125,9 @@ class Script(scripts.Script): # We don't shrink back to 1, because that causes the control to ignore [enter], and it may # be unclear to the user that shift-enter is needed. prompt_txt.change(lambda tb: gr.update(lines=7) if ("\n" in tb) else gr.update(lines=2), inputs=[prompt_txt], outputs=[prompt_txt], show_progress=False) - return [checkbox_iterate, checkbox_iterate_batch, prompt_txt] + return [checkbox_iterate, checkbox_iterate_batch, prompt_position, prompt_txt] - def run(self, p, checkbox_iterate, checkbox_iterate_batch, prompt_txt: str): + def run(self, p, checkbox_iterate, checkbox_iterate_batch, prompt_position, prompt_txt: str): lines = [x for x in (x.strip() for x in prompt_txt.splitlines()) if x] p.do_not_save_grid = True @@ -167,6 +168,18 @@ class Script(scripts.Script): else: setattr(copy_p, k, v) + if args.get("prompt") and p.prompt: + if prompt_position == "start": + copy_p.prompt = args.get("prompt") + " " + p.prompt + else: + copy_p.prompt = p.prompt + " " + args.get("prompt") + + if args.get("negative_prompt") and p.negative_prompt: + if prompt_position == "start": + copy_p.negative_prompt = args.get("negative_prompt") + " " + p.negative_prompt + else: + copy_p.negative_prompt = p.negative_prompt + " " + args.get("negative_prompt") + proc = process_images(copy_p) images += proc.images diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py index 0dc255bc4..2f385ebf2 100644 --- a/scripts/xyz_grid.py +++ b/scripts/xyz_grid.py @@ -270,6 +270,7 @@ axis_options = [ AxisOption("Refiner checkpoint", str, apply_field('refiner_checkpoint'), format_value=format_remove_path, confirm=confirm_checkpoints_or_none, cost=1.0, choices=lambda: ['None'] + sorted(sd_models.checkpoints_list, key=str.casefold)), AxisOption("Refiner switch at", float, apply_field('refiner_switch_at')), AxisOption("RNG source", str, apply_override("randn_source"), choices=lambda: ["GPU", "CPU", "NV"]), + AxisOption("FP8 mode", str, apply_override("fp8_storage"), cost=0.9, choices=lambda: ["Disable", "Enable for SDXL", "Enable"]), ] @@ -437,13 +438,16 @@ class Script(scripts.Script): with gr.Column(): draw_legend = gr.Checkbox(label='Draw legend', value=True, elem_id=self.elem_id("draw_legend")) no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False, elem_id=self.elem_id("no_fixed_seeds")) + with gr.Row(): + vary_seeds_x = gr.Checkbox(label='Vary seeds for X', value=False, min_width=80, elem_id=self.elem_id("vary_seeds_x"), tooltip="Use different seeds for images along X axis.") + vary_seeds_y = gr.Checkbox(label='Vary seeds for Y', value=False, min_width=80, elem_id=self.elem_id("vary_seeds_y"), tooltip="Use different seeds for images along Y axis.") + vary_seeds_z = gr.Checkbox(label='Vary seeds for Z', value=False, min_width=80, elem_id=self.elem_id("vary_seeds_z"), tooltip="Use different seeds for images along Z axis.") with gr.Column(): include_lone_images = gr.Checkbox(label='Include Sub Images', value=False, elem_id=self.elem_id("include_lone_images")) include_sub_grids = gr.Checkbox(label='Include Sub Grids', value=False, elem_id=self.elem_id("include_sub_grids")) + csv_mode = gr.Checkbox(label='Use text inputs instead of dropdowns', value=False, elem_id=self.elem_id("csv_mode")) with gr.Column(): margin_size = gr.Slider(label="Grid margins (px)", minimum=0, maximum=500, value=0, step=2, elem_id=self.elem_id("margin_size")) - with gr.Column(): - csv_mode = gr.Checkbox(label='Use text inputs instead of dropdowns', value=False, elem_id=self.elem_id("csv_mode")) with gr.Row(variant="compact", elem_id="swap_axes"): swap_xy_axes_button = gr.Button(value="Swap X/Y axes", elem_id="xy_grid_swap_axes_button") @@ -475,6 +479,8 @@ class Script(scripts.Script): fill_z_button.click(fn=fill, inputs=[z_type, csv_mode], outputs=[z_values, z_values_dropdown]) def select_axis(axis_type, axis_values, axis_values_dropdown, csv_mode): + axis_type = axis_type or 0 # if axle type is None set to 0 + choices = self.current_axis_options[axis_type].choices has_choices = choices is not None @@ -522,9 +528,11 @@ class Script(scripts.Script): (z_values_dropdown, lambda params: get_dropdown_update_from_params("Z", params)), ) - return [x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size, csv_mode] + return [x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, vary_seeds_x, vary_seeds_y, vary_seeds_z, margin_size, csv_mode] + + def run(self, p, x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, vary_seeds_x, vary_seeds_y, vary_seeds_z, margin_size, csv_mode): + x_type, y_type, z_type = x_type or 0, y_type or 0, z_type or 0 # if axle type is None set to 0 - def run(self, p, x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size, csv_mode): if not no_fixed_seeds: modules.processing.fix_seed(p) @@ -688,7 +696,7 @@ class Script(scripts.Script): grid_infotext = [None] * (1 + len(zs)) def cell(x, y, z, ix, iy, iz): - if shared.state.interrupted: + if shared.state.interrupted or state.stopping_generation: return Processed(p, [], p.seed, "") pc = copy(p) @@ -697,6 +705,16 @@ class Script(scripts.Script): y_opt.apply(pc, y, ys) z_opt.apply(pc, z, zs) + xdim = len(xs) if vary_seeds_x else 1 + ydim = len(ys) if vary_seeds_y else 1 + + if vary_seeds_x: + pc.seed += ix + if vary_seeds_y: + pc.seed += iy * xdim + if vary_seeds_z: + pc.seed += iz * xdim * ydim + try: res = process_images(pc) except Exception as e: diff --git a/style.css b/style.css index 115626cd9..6d4c8a0d5 100644 --- a/style.css +++ b/style.css @@ -204,6 +204,11 @@ div.block.gradio-accordion { padding: 8px 8px; } +input[type="checkbox"].input-accordion-checkbox{ + vertical-align: sub; + margin-right: 0.5em; +} + /* txt2img/img2img specific */ @@ -291,6 +296,13 @@ div.block.gradio-accordion { min-height: 4.5em; } +#txt2img_generate, #img2img_generate { + min-height: 4.5em; +} +.generate-box-compact #txt2img_generate, .generate-box-compact #img2img_generate { + min-height: 3em; +} + @media screen and (min-width: 2500px) { #txt2img_gallery, #img2img_gallery { min-height: 768px; @@ -398,6 +410,15 @@ div#extras_scale_to_tab div.form{ min-width: 0.5em; } +div.toprow-compact-stylerow{ + margin: 0.5em 0; +} + +div.toprow-compact-tools{ + min-width: fit-content !important; + max-width: fit-content; +} + /* settings */ #quicksettings { align-items: end; @@ -441,6 +462,15 @@ div#extras_scale_to_tab div.form{ padding: 4px; } +#settings > div.tab-nav .settings-category{ + display: block; + margin: 1em 0 0.25em 0; + font-weight: bold; + text-decoration: underline; + cursor: default; + user-select: none; +} + #settings_result{ height: 1.4em; margin: 0 1.2em; @@ -520,7 +550,8 @@ table.popup-table .link{ height: 20px; background: #b4c0cc; border-radius: 3px !important; - top: -20px; + top: -14px; + left: 0px; width: 100%; } @@ -615,6 +646,8 @@ table.popup-table .link{ margin: auto; padding: 2em; z-index: 1001; + max-height: 90%; + max-width: 90%; } /* fullpage image viewer */ @@ -646,7 +679,7 @@ table.popup-table .link{ transition: 0.2s ease background-color; } .modalControls:hover { - background-color:rgba(0,0,0,0.9); + background-color:rgba(0,0,0, var(--sd-webui-modal-lightbox-toolbar-opacity)); } .modalClose { margin-left: auto; @@ -716,6 +749,22 @@ table.popup-table .link{ display: none; } +@media (pointer: fine) { + .modalPrev:hover, + .modalNext:hover, + .modalControls:hover ~ .modalPrev, + .modalControls:hover ~ .modalNext, + .modalControls:hover .cursor { + opacity: 1; + } + + .modalPrev, + .modalNext, + .modalControls .cursor { + opacity: var(--sd-webui-modal-lightbox-icon-opacity); + } +} + /* context menu (ie for the generate button) */ #context-menu{ @@ -818,6 +867,18 @@ footer { /* extra networks UI */ +.extra-page > div.gap{ + gap: 0; +} + +.extra-page-prompts{ + margin-bottom: 0; +} + +.extra-page-prompts.extra-page-prompts-active{ + margin-bottom: 1em; +} + .extra-network-cards{ height: calc(100vh - 24rem); overflow: clip scroll; diff --git a/test/conftest.py b/test/conftest.py index 31a5d9eaf..e4fc56785 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -1,10 +1,16 @@ +import base64 import os import pytest -import base64 - test_files_path = os.path.dirname(__file__) + "/test_files" +test_outputs_path = os.path.dirname(__file__) + "/test_outputs" + + +def pytest_configure(config): + # We don't want to fail on Py.test command line arguments being + # parsed by webui: + os.environ.setdefault("IGNORE_CMD_ARGS_ERRORS", "1") def file_to_base64(filename): @@ -23,3 +29,8 @@ def img2img_basic_image_base64() -> str: @pytest.fixture(scope="session") # session so we don't read this over and over def mask_basic_image_base64() -> str: return file_to_base64(os.path.join(test_files_path, "mask_basic.png")) + + +@pytest.fixture(scope="session") +def initialize() -> None: + import webui # noqa: F401 diff --git a/test/test_face_restorers.py b/test/test_face_restorers.py new file mode 100644 index 000000000..7760d51bf --- /dev/null +++ b/test/test_face_restorers.py @@ -0,0 +1,29 @@ +import os +from test.conftest import test_files_path, test_outputs_path + +import numpy as np +import pytest +from PIL import Image + + +@pytest.mark.usefixtures("initialize") +@pytest.mark.parametrize("restorer_name", ["gfpgan", "codeformer"]) +def test_face_restorers(restorer_name): + from modules import shared + + if restorer_name == "gfpgan": + from modules import gfpgan_model + gfpgan_model.setup_model(shared.cmd_opts.gfpgan_models_path) + restorer = gfpgan_model.gfpgan_fix_faces + elif restorer_name == "codeformer": + from modules import codeformer_model + codeformer_model.setup_model(shared.cmd_opts.codeformer_models_path) + restorer = codeformer_model.codeformer.restore + else: + raise NotImplementedError("...") + img = Image.open(os.path.join(test_files_path, "two-faces.jpg")) + np_img = np.array(img, dtype=np.uint8) + fixed_image = restorer(np_img) + assert fixed_image.shape == np_img.shape + assert not np.allclose(fixed_image, np_img) # should have visibly changed + Image.fromarray(fixed_image).save(os.path.join(test_outputs_path, f"{restorer_name}.png")) diff --git a/test/test_files/two-faces.jpg b/test/test_files/two-faces.jpg new file mode 100644 index 000000000..c9d1b0103 Binary files /dev/null and b/test/test_files/two-faces.jpg differ diff --git a/test/test_outputs/.gitkeep b/test/test_outputs/.gitkeep new file mode 100644 index 000000000..e69de29bb diff --git a/test/test_torch_utils.py b/test/test_torch_utils.py new file mode 100644 index 000000000..23ccb93a4 --- /dev/null +++ b/test/test_torch_utils.py @@ -0,0 +1,19 @@ +import types + +import pytest +import torch + +from modules import torch_utils + + +@pytest.mark.parametrize("wrapped", [True, False]) +def test_get_param(wrapped): + mod = torch.nn.Linear(1, 1) + cpu = torch.device("cpu") + mod.to(dtype=torch.float16, device=cpu) + if wrapped: + # more or less how spandrel wraps a thing + mod = types.SimpleNamespace(model=mod) + p = torch_utils.get_param(mod) + assert p.dtype == torch.float16 + assert p.device == cpu diff --git a/webui-macos-env.sh b/webui-macos-env.sh index 24bc5c426..db7e8b1a0 100644 --- a/webui-macos-env.sh +++ b/webui-macos-env.sh @@ -11,7 +11,7 @@ fi export install_dir="$HOME" export COMMANDLINE_ARGS="--skip-torch-cuda-test --upcast-sampling --no-half-vae --use-cpu interrogate" -export TORCH_COMMAND="pip install torch==2.0.1 torchvision==0.15.2" +export TORCH_COMMAND="pip install torch==2.1.0 torchvision==0.16.0" export PYTORCH_ENABLE_MPS_FALLBACK=1 #################################################################### diff --git a/webui.py b/webui.py index 9ed20b306..2c417168a 100644 --- a/webui.py +++ b/webui.py @@ -39,7 +39,7 @@ def api_only(): print(f"Startup time: {startup_timer.summary()}.") api.launch( - server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", + server_name=initialize_util.gradio_server_name(), port=cmd_opts.port if cmd_opts.port else 7861, root_path=f"/{cmd_opts.subpath}" if cmd_opts.subpath else "" ) diff --git a/webui.sh b/webui.sh index 089114692..38258ef63 100755 --- a/webui.sh +++ b/webui.sh @@ -89,7 +89,7 @@ delimiter="################################################################" printf "\n%s\n" "${delimiter}" printf "\e[1m\e[32mInstall script for stable-diffusion + Web UI\n" -printf "\e[1m\e[34mTested on Debian 11 (Bullseye)\e[0m" +printf "\e[1m\e[34mTested on Debian 11 (Bullseye), Fedora 34+ and openSUSE Leap 15.4 or newer.\e[0m" printf "\n%s\n" "${delimiter}" # Do not run as root @@ -133,7 +133,7 @@ case "$gpu_info" in if [[ $(bc <<< "$pyv <= 3.10") -eq 1 ]] then # Navi users will still use torch 1.13 because 2.0 does not seem to work. - export TORCH_COMMAND="pip install torch==1.13.1+rocm5.2 torchvision==0.14.1+rocm5.2 --index-url https://download.pytorch.org/whl/rocm5.2" + export TORCH_COMMAND="pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/rocm5.6" else printf "\e[1m\e[31mERROR: RX 5000 series GPUs must be using at max python 3.10, aborting...\e[0m" exit 1 @@ -143,8 +143,7 @@ case "$gpu_info" in *"Navi 2"*) export HSA_OVERRIDE_GFX_VERSION=10.3.0 ;; *"Navi 3"*) [[ -z "${TORCH_COMMAND}" ]] && \ - export TORCH_COMMAND="pip install torch torchvision --index-url https://download.pytorch.org/whl/test/rocm5.6" - # Navi 3 needs at least 5.5 which is only on the torch 2.1.0 release candidates right now + export TORCH_COMMAND="pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/rocm5.7" ;; *"Renoir"*) export HSA_OVERRIDE_GFX_VERSION=9.0.0 printf "\n%s\n" "${delimiter}" @@ -223,13 +222,30 @@ fi # Try using TCMalloc on Linux prepare_tcmalloc() { if [[ "${OSTYPE}" == "linux"* ]] && [[ -z "${NO_TCMALLOC}" ]] && [[ -z "${LD_PRELOAD}" ]]; then - TCMALLOC="$(PATH=/usr/sbin:$PATH ldconfig -p | grep -Po "libtcmalloc(_minimal|)\.so\.\d" | head -n 1)" - if [[ ! -z "${TCMALLOC}" ]]; then - echo "Using TCMalloc: ${TCMALLOC}" - export LD_PRELOAD="${TCMALLOC}" - else - printf "\e[1m\e[31mCannot locate TCMalloc (improves CPU memory usage)\e[0m\n" - fi + # Define Tcmalloc Libs arrays + TCMALLOC_LIBS=("libtcmalloc(_minimal|)\.so\.\d" "libtcmalloc\.so\.\d") + + # Traversal array + for lib in "${TCMALLOC_LIBS[@]}" + do + #Determine which type of tcmalloc library the library supports + TCMALLOC="$(PATH=/usr/sbin:$PATH ldconfig -p | grep -P $lib | head -n 1)" + TC_INFO=(${TCMALLOC//=>/}) + if [[ ! -z "${TC_INFO}" ]]; then + echo "Using TCMalloc: ${TC_INFO}" + #Determine if the library is linked to libptthread and resolve undefined symbol: ptthread_Key_Create + if ldd ${TC_INFO[2]} | grep -q 'libpthread'; then + echo "$TC_INFO is linked with libpthread,execute LD_PRELOAD=${TC_INFO}" + export LD_PRELOAD="${TC_INFO}" + break + else + echo "$TC_INFO is not linked with libpthreadand will trigger undefined symbol: ptthread_Key_Create error" + fi + else + printf "\e[1m\e[31mCannot locate TCMalloc (improves CPU memory usage)\e[0m\n" + fi + done + fi }