changes for inpainting for #35

support for --medvram
attempt to support share
AUTOMATIC 2022-09-01 11:41:42 +03:00
parent 3e4103541c
commit e1648fc1d1
2 changed files with 76 additions and 53 deletions


@@ -71,10 +71,10 @@ Run the command to start web ui:
python stable-diffusion-webui/webui.py
```
If you have a 4GB video card, run the command with the `--lowvram` argument:
If you have a 4GB video card, run the command with either the `--lowvram` or `--medvram` argument:
```
python stable-diffusion-webui/webui.py --lowvram
python stable-diffusion-webui/webui.py --medvram
```
After a while, you will get a message like this:
@@ -280,17 +280,18 @@ print("Seed was: " + str(processed.seed))
display(processed.images, processed.seed, processed.info)
```
### `--lowvram`
### 4GB videocard support
Optimizations for GPUs with low VRAM. This should make it possible to generate 512x512 images on video cards with 4GB of memory.
The original idea for these optimizations is by basujindal: https://github.com/basujindal/stable-diffusion. The model is separated into modules,
and only one module is kept in GPU memory; when another module needs to run, the previous one is removed from GPU memory.
It should be obvious, but the nature of these optimizations makes processing run slower -- about 10 times slower
`--lowvram` is a reimplementation of an optimization idea by [basujindal](https://github.com/basujindal/stable-diffusion).
The model is separated into modules, and only one module is kept in GPU memory; when another module needs to run, the previous
one is removed from GPU memory. The nature of this optimization makes processing run slower -- about 10 times slower
compared to normal operation on my RTX 3090.
This is an independent implementation that does not require any modification to the original Stable Diffusion code,
with all code concentrated in one place rather than scattered around the program.
`--medvram` is another optimization that should reduce VRAM usage significantly by not processing conditional and
unconditional denoising in the same batch.
This implementation does not require any modification to the original Stable Diffusion code.
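As a rough illustration of the `--lowvram` approach, here is a minimal PyTorch sketch (a toy model, not the actual webui code): a forward pre-hook moves each module to the GPU just before it runs and evicts whichever module was resident before.
```
import torch
import torch.nn as nn

gpu = torch.device("cuda")
cpu = torch.device("cpu")
module_in_gpu = None  # the single module currently resident on the GPU

def send_me_to_gpu(module, _inputs):
    # forward pre-hook: evict the previous module, then bring this one in
    global module_in_gpu
    if module_in_gpu is module:
        return
    if module_in_gpu is not None:
        module_in_gpu.to(cpu)
    module.to(gpu)
    module_in_gpu = module

model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64)).to(cpu)
for submodule in model:
    submodule.register_forward_pre_hook(send_me_to_gpu)

# model(torch.randn(1, 64, device=gpu))  # each submodule hops onto the GPU just in time
```
The constant shuttling of weights between CPU and GPU is what accounts for the roughly 10x slowdown mentioned above.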
### Inpainting
In img2img tab, draw a mask over a part of image, and that part will be in-painted.

webui.py

@@ -6,7 +6,10 @@ script_path = os.path.dirname(os.path.realpath(__file__))
sd_path = os.path.dirname(script_path)
# add parent directory to path; this is where the Stable Diffusion repo should be
path_dirs = [(sd_path, 'ldm', 'Stable Diffusion'), ('../../taming-transformers', 'taming', 'Taming Transformers')]
path_dirs = [
(sd_path, 'ldm', 'Stable Diffusion'),
('../../taming-transformers', 'taming', 'Taming Transformers')
]
for d, must_exist, what in path_dirs:
must_exist_path = os.path.abspath(os.path.join(script_path, d, must_exist))
if not os.path.exists(must_exist_path):
@@ -38,15 +41,10 @@ from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
# fix gradio phoning home
gradio.utils.version_check = lambda: None
gradio.utils.get_local_ip_address = lambda: '127.0.0.1'
# this is a fix for Windows users. Without it, JavaScript files will be served with text/html content-type and the browser will not show any UI
mimetypes.init()
mimetypes.add_type('application/javascript', '.js')
# some of those options should not be changed at all because they would break the model, so I removed them from options.
opt_C = 4
opt_f = 8
@@ -65,14 +63,21 @@ parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
parser.add_argument("--embeddings-dir", type=str, default='embeddings', help="embeddings dirtectory for textual inversion (default: embeddings)")
parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
parser.add_argument("--lowvram", action='store_true', help="enamble stable diffusion model optimizations for low vram")
parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrficing a little speed for low VRM usage")
parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrficing a lot of speed for very low VRM usage")
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site (doesn't work for me but you might have better luck)")
cmd_opts = parser.parse_args()
cpu = torch.device("cpu")
gpu = torch.device("cuda")
device = gpu if torch.cuda.is_available() else cpu
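# process conditional and unconditional denoising in one batch (faster, but higher peak VRAM); low-VRAM modes run them separately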
batch_cond_uncond = not (cmd_opts.lowvram or cmd_opts.medvram)
if not cmd_opts.share:
# fix gradio phoning home
gradio.utils.version_check = lambda: None
gradio.utils.get_local_ip_address = lambda: '127.0.0.1'
css_hide_progressbar = """
.wrap .m-12 svg { display:none!important; }
@@ -294,21 +299,25 @@ def setup_for_low_vram(sd_model):
sd_model.first_stage_model.decode = lambda z, de=sd_model.first_stage_model.decode: first_stage_model_decode_wrap(sd_model.first_stage_model, de, z)
parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
# the third remaining model is still too big for 4GB, so we also do the same for its submodules
# so that only one of them is in GPU at a time
diff_model = sd_model.model.diffusion_model
stored = diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed
diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = None, None, None, None
sd_model.model.to(device)
diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = stored
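# --medvram moves the whole diffusion model to the GPU as one unit; --lowvram additionally splits it into submodules below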
if cmd_opts.medvram:
sd_model.model.register_forward_pre_hook(send_me_to_gpu)
else:
diff_model = sd_model.model.diffusion_model
# install hooks for bits of third model
diff_model.time_embed.register_forward_pre_hook(send_me_to_gpu)
for block in diff_model.input_blocks:
block.register_forward_pre_hook(send_me_to_gpu)
diff_model.middle_block.register_forward_pre_hook(send_me_to_gpu)
for block in diff_model.output_blocks:
block.register_forward_pre_hook(send_me_to_gpu)
# the third remaining model is still too big for 4GB, so we also do the same for its submodules
# so that only one of them is in GPU at a time
stored = diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed
diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = None, None, None, None
sd_model.model.to(device)
diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = stored
# install hooks for bits of third model
diff_model.time_embed.register_forward_pre_hook(send_me_to_gpu)
for block in diff_model.input_blocks:
block.register_forward_pre_hook(send_me_to_gpu)
diff_model.middle_block.register_forward_pre_hook(send_me_to_gpu)
for block in diff_model.output_blocks:
block.register_forward_pre_hook(send_me_to_gpu)
def create_random_tensors(shape, seeds):
@@ -860,7 +869,7 @@ class VanillaStableDiffusionSampler:
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning):
t_enc = int(min(p.denoising_strength, 0.999) * p.steps)
# existing code fail with certain step counts, like 9
# existing code fails with certain step counts, like 9
try:
self.sampler.make_schedule(ddim_num_steps=p.steps, verbose=False)
except Exception:
@@ -887,13 +896,26 @@ class CFGDenoiser(nn.Module):
def __init__(self, model):
super().__init__()
self.inner_model = model
self.mask = None
self.nmask = None
self.init_latent = None
def forward(self, x, sigma, uncond, cond, cond_scale):
x_in = torch.cat([x] * 2)
sigma_in = torch.cat([sigma] * 2)
cond_in = torch.cat([uncond, cond])
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
return uncond + (cond - uncond) * cond_scale
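# with low-VRAM optimizations active, run the conditional and unconditional passes one after the other to halve the peak batch size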
if batch_cond_uncond:
x_in = torch.cat([x] * 2)
sigma_in = torch.cat([sigma] * 2)
cond_in = torch.cat([uncond, cond])
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
denoised = uncond + (cond - uncond) * cond_scale
else:
uncond = self.inner_model(x, sigma, cond=uncond)
cond = self.inner_model(x, sigma, cond=cond)
denoised = uncond + (cond - uncond) * cond_scale
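# inpainting: keep the original latent where the mask preserves the image, use the denoised result elsewhere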
if self.mask is not None:
denoised = self.init_latent * self.mask + self.nmask * denoised
return denoised
class KDiffusionSampler:
@@ -910,19 +932,13 @@ class KDiffusionSampler:
xi = x + noise
if p.mask is not None:
if p.inpainting_fill == 2:
xi = xi * p.mask + noise * p.nmask
elif p.inpainting_fill == 3:
xi = xi * p.mask
sigma_sched = sigmas[p.steps - t_enc - 1:]
def mask_cb(v):
v["denoised"][:] = v["denoised"][:] * p.nmask + p.init_latent * p.mask
return self.func(self.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=mask_cb if p.mask is not None else None)
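# hand the mask data to the CFG denoiser so it can blend the original latent back in after every step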
self.model_wrap_cfg.mask = p.mask
self.model_wrap_cfg.nmask = p.nmask
self.model_wrap_cfg.init_latent = p.init_latent
return self.func(self.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False)
def sample(self, p: StableDiffusionProcessing, x, conditioning, unconditional_conditioning):
sigmas = self.model_wrap.get_sigmas(p.steps)
@@ -932,7 +948,7 @@ class KDiffusionSampler:
return samples_ddim
Processed = namedtuple('Processed', ['images','seed', 'info'])
Processed = namedtuple('Processed', ['images', 'seed', 'info'])
def process_images(p: StableDiffusionProcessing) -> Processed:
@@ -1315,7 +1331,6 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
if self.mask_blur > 0:
self.image_mask = self.image_mask.filter(ImageFilter.GaussianBlur(self.mask_blur)).convert('L')
if self.inpaint_full_res:
self.mask_for_overlay = self.image_mask
mask = self.image_mask.convert('L')
@@ -1383,6 +1398,13 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.nmask = torch.asarray(latmask).to(device).type(sd_model.dtype)
def sample(self, x, conditioning, unconditional_conditioning):
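# inpainting fill modes (values assumed from the UI choices): 2 fills the masked area with fresh noise, 3 zeroes it out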
if self.mask is not None:
if self.inpainting_fill == 2:
x = x * self.mask + create_random_tensors(x.shape[1:], [self.seed + i + 1 for i in range(x.shape[0])]) * self.nmask
elif self.inpainting_fill == 3:
x = x * self.mask
samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning)
if self.mask is not None:
@@ -1805,10 +1827,10 @@ sd_config = OmegaConf.load(cmd_opts.config)
sd_model = load_model_from_config(sd_config, cmd_opts.ckpt)
sd_model = (sd_model if cmd_opts.no_half else sd_model.half())
if not cmd_opts.lowvram:
sd_model = sd_model.to(device)
else:
if cmd_opts.lowvram or cmd_opts.medvram:
setup_for_low_vram(sd_model)
else:
sd_model = sd_model.to(device)
model_hijack = StableDiffusionModelHijack()
model_hijack.hijack(sd_model)
@@ -1855,5 +1877,5 @@ def inject_gradio_html(javascript):
inject_gradio_html(javascript)
demo.queue(concurrency_count=1)
demo.launch()
demo.launch(share=cmd_opts.share)