From 125319988984987801dc4b4ab1e5ed36e9b211c5 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Fri, 10 Feb 2023 03:30:20 -0800 Subject: [PATCH 001/104] Working UniPC (for batch size 1) --- javascript/hints.js | 1 + modules/models/diffusion/uni_pc/__init__.py | 1 + modules/models/diffusion/uni_pc/sampler.py | 85 ++ modules/models/diffusion/uni_pc/uni_pc.py | 858 ++++++++++++++++++++ modules/processing.py | 2 +- modules/sd_samplers_compvis.py | 35 +- test/basic_features/txt2img_test.py | 2 + 7 files changed, 978 insertions(+), 6 deletions(-) create mode 100644 modules/models/diffusion/uni_pc/__init__.py create mode 100644 modules/models/diffusion/uni_pc/sampler.py create mode 100644 modules/models/diffusion/uni_pc/uni_pc.py diff --git a/javascript/hints.js b/javascript/hints.js index 9aa82f246..0a0620e39 100644 --- a/javascript/hints.js +++ b/javascript/hints.js @@ -6,6 +6,7 @@ titles = { "GFPGAN": "Restore low quality faces using GFPGAN neural network", "Euler a": "Euler Ancestral - very creative, each can get a completely different picture depending on step count, setting steps higher than 30-40 does not help", "DDIM": "Denoising Diffusion Implicit Models - best at inpainting", + "UniPC": "Unified Predictor-Corrector Framework for Fast Sampling of Diffusion Models", "DPM adaptive": "Ignores step count - uses a number of steps determined by the CFG and resolution", "Batch count": "How many batches of images to create", diff --git a/modules/models/diffusion/uni_pc/__init__.py b/modules/models/diffusion/uni_pc/__init__.py new file mode 100644 index 000000000..e1265e3fe --- /dev/null +++ b/modules/models/diffusion/uni_pc/__init__.py @@ -0,0 +1 @@ +from .sampler import UniPCSampler diff --git a/modules/models/diffusion/uni_pc/sampler.py b/modules/models/diffusion/uni_pc/sampler.py new file mode 100644 index 000000000..7cccd8a24 --- /dev/null +++ b/modules/models/diffusion/uni_pc/sampler.py @@ -0,0 +1,85 @@ +"""SAMPLING ONLY.""" + +import torch + +from .uni_pc import NoiseScheduleVP, model_wrapper, UniPC + +class UniPCSampler(object): + def __init__(self, model, **kwargs): + super().__init__() + self.model = model + to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device) + self.before_sample = None + self.after_sample = None + self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod)) + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + def set_hooks(self, before, after): + self.before_sample = before + self.after_sample = after + + @torch.no_grad() + def sample(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + cbs = conditioning[list(conditioning.keys())[0]].shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + + device = self.model.betas.device + if x_T is None: + img = torch.randn(size, device=device) + else: + img = x_T + + ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod) + + model_fn = model_wrapper( + lambda x, t, c: self.model.apply_model(x, t, c), + ns, + model_type="noise", + guidance_type="classifier-free", + #condition=conditioning, + #unconditional_condition=unconditional_conditioning, + guidance_scale=unconditional_guidance_scale, + ) + + uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, condition=conditioning, unconditional_condition=unconditional_conditioning, before_sample=self.before_sample, after_sample=self.after_sample) + x = uni_pc.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=3, lower_order_final=True) + + return x.to(device), None diff --git a/modules/models/diffusion/uni_pc/uni_pc.py b/modules/models/diffusion/uni_pc/uni_pc.py new file mode 100644 index 000000000..ec6b37da7 --- /dev/null +++ b/modules/models/diffusion/uni_pc/uni_pc.py @@ -0,0 +1,858 @@ +import torch +import torch.nn.functional as F +import math + + +class NoiseScheduleVP: + def __init__( + self, + schedule='discrete', + betas=None, + alphas_cumprod=None, + continuous_beta_0=0.1, + continuous_beta_1=20., + ): + """Create a wrapper class for the forward SDE (VP type). + + *** + Update: We support discrete-time diffusion models by implementing a picewise linear interpolation for log_alpha_t. + We recommend to use schedule='discrete' for the discrete-time diffusion models, especially for high-resolution images. + *** + + The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ). + We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper). + Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have: + + log_alpha_t = self.marginal_log_mean_coeff(t) + sigma_t = self.marginal_std(t) + lambda_t = self.marginal_lambda(t) + + Moreover, as lambda(t) is an invertible function, we also support its inverse function: + + t = self.inverse_lambda(lambda_t) + + =============================================================== + + We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]). + + 1. For discrete-time DPMs: + + For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by: + t_i = (i + 1) / N + e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1. + We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3. + + Args: + betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details) + alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details) + + Note that we always have alphas_cumprod = cumprod(betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`. + + **Important**: Please pay special attention for the args for `alphas_cumprod`: + The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that + q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ). + Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have + alpha_{t_n} = \sqrt{\hat{alpha_n}}, + and + log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}). + + + 2. For continuous-time DPMs: + + We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise + schedule are the default settings in DDPM and improved-DDPM: + + Args: + beta_min: A `float` number. The smallest beta for the linear schedule. + beta_max: A `float` number. The largest beta for the linear schedule. + cosine_s: A `float` number. The hyperparameter in the cosine schedule. + cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule. + T: A `float` number. The ending time of the forward process. + + =============================================================== + + Args: + schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs, + 'linear' or 'cosine' for continuous-time DPMs. + Returns: + A wrapper object of the forward SDE (VP type). + + =============================================================== + + Example: + + # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1): + >>> ns = NoiseScheduleVP('discrete', betas=betas) + + # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1): + >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod) + + # For continuous-time DPMs (VPSDE), linear schedule: + >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.) + + """ + + if schedule not in ['discrete', 'linear', 'cosine']: + raise ValueError("Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(schedule)) + + self.schedule = schedule + if schedule == 'discrete': + if betas is not None: + log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0) + else: + assert alphas_cumprod is not None + log_alphas = 0.5 * torch.log(alphas_cumprod) + self.total_N = len(log_alphas) + self.T = 1. + self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1)) + self.log_alpha_array = log_alphas.reshape((1, -1,)) + else: + self.total_N = 1000 + self.beta_0 = continuous_beta_0 + self.beta_1 = continuous_beta_1 + self.cosine_s = 0.008 + self.cosine_beta_max = 999. + self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s + self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.)) + self.schedule = schedule + if schedule == 'cosine': + # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T. + # Note that T = 0.9946 may be not the optimal setting. However, we find it works well. + self.T = 0.9946 + else: + self.T = 1. + + def marginal_log_mean_coeff(self, t): + """ + Compute log(alpha_t) of a given continuous-time label t in [0, T]. + """ + if self.schedule == 'discrete': + return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), self.log_alpha_array.to(t.device)).reshape((-1)) + elif self.schedule == 'linear': + return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 + elif self.schedule == 'cosine': + log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.)) + log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0 + return log_alpha_t + + def marginal_alpha(self, t): + """ + Compute alpha_t of a given continuous-time label t in [0, T]. + """ + return torch.exp(self.marginal_log_mean_coeff(t)) + + def marginal_std(self, t): + """ + Compute sigma_t of a given continuous-time label t in [0, T]. + """ + return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t))) + + def marginal_lambda(self, t): + """ + Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T]. + """ + log_mean_coeff = self.marginal_log_mean_coeff(t) + log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff)) + return log_mean_coeff - log_std + + def inverse_lambda(self, lamb): + """ + Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t. + """ + if self.schedule == 'linear': + tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb)) + Delta = self.beta_0**2 + tmp + return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0) + elif self.schedule == 'discrete': + log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb) + t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), torch.flip(self.t_array.to(lamb.device), [1])) + return t.reshape((-1,)) + else: + log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb)) + t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s + t = t_fn(log_alpha) + return t + + +def model_wrapper( + model, + noise_schedule, + model_type="noise", + model_kwargs={}, + guidance_type="uncond", + #condition=None, + #unconditional_condition=None, + guidance_scale=1., + classifier_fn=None, + classifier_kwargs={}, +): + """Create a wrapper function for the noise prediction model. + + DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to + firstly wrap the model function to a noise prediction model that accepts the continuous time as the input. + + We support four types of the diffusion model by setting `model_type`: + + 1. "noise": noise prediction model. (Trained by predicting noise). + + 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0). + + 3. "v": velocity prediction model. (Trained by predicting the velocity). + The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2]. + + [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models." + arXiv preprint arXiv:2202.00512 (2022). + [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models." + arXiv preprint arXiv:2210.02303 (2022). + + 4. "score": marginal score function. (Trained by denoising score matching). + Note that the score function and the noise prediction model follows a simple relationship: + ``` + noise(x_t, t) = -sigma_t * score(x_t, t) + ``` + + We support three types of guided sampling by DPMs by setting `guidance_type`: + 1. "uncond": unconditional sampling by DPMs. + The input `model` has the following format: + `` + model(x, t_input, **model_kwargs) -> noise | x_start | v | score + `` + + 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier. + The input `model` has the following format: + `` + model(x, t_input, **model_kwargs) -> noise | x_start | v | score + `` + + The input `classifier_fn` has the following format: + `` + classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond) + `` + + [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis," + in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794. + + 3. "classifier-free": classifier-free guidance sampling by conditional DPMs. + The input `model` has the following format: + `` + model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score + `` + And if cond == `unconditional_condition`, the model output is the unconditional DPM output. + + [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance." + arXiv preprint arXiv:2207.12598 (2022). + + + The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999) + or continuous-time labels (i.e. epsilon to T). + + We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise: + `` + def model_fn(x, t_continuous) -> noise: + t_input = get_model_input_time(t_continuous) + return noise_pred(model, x, t_input, **model_kwargs) + `` + where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver. + + =============================================================== + + Args: + model: A diffusion model with the corresponding format described above. + noise_schedule: A noise schedule object, such as NoiseScheduleVP. + model_type: A `str`. The parameterization type of the diffusion model. + "noise" or "x_start" or "v" or "score". + model_kwargs: A `dict`. A dict for the other inputs of the model function. + guidance_type: A `str`. The type of the guidance for sampling. + "uncond" or "classifier" or "classifier-free". + condition: A pytorch tensor. The condition for the guided sampling. + Only used for "classifier" or "classifier-free" guidance type. + unconditional_condition: A pytorch tensor. The condition for the unconditional sampling. + Only used for "classifier-free" guidance type. + guidance_scale: A `float`. The scale for the guided sampling. + classifier_fn: A classifier function. Only used for the classifier guidance. + classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function. + Returns: + A noise prediction model that accepts the noised data and the continuous time as the inputs. + """ + + def get_model_input_time(t_continuous): + """ + Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time. + For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N]. + For continuous-time DPMs, we just use `t_continuous`. + """ + if noise_schedule.schedule == 'discrete': + return (t_continuous - 1. / noise_schedule.total_N) * 1000. + else: + return t_continuous + + def noise_pred_fn(x, t_continuous, cond=None): + if t_continuous.reshape((-1,)).shape[0] == 1: + t_continuous = t_continuous.expand((x.shape[0])) + t_input = get_model_input_time(t_continuous) + if cond is None: + output = model(x, t_input, None, **model_kwargs) + else: + output = model(x, t_input, cond, **model_kwargs) + if model_type == "noise": + return output + elif model_type == "x_start": + alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) + dims = x.dim() + return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims) + elif model_type == "v": + alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) + dims = x.dim() + return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x + elif model_type == "score": + sigma_t = noise_schedule.marginal_std(t_continuous) + dims = x.dim() + return -expand_dims(sigma_t, dims) * output + + def cond_grad_fn(x, t_input, condition): + """ + Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t). + """ + with torch.enable_grad(): + x_in = x.detach().requires_grad_(True) + log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs) + return torch.autograd.grad(log_prob.sum(), x_in)[0] + + def model_fn(x, t_continuous, condition, unconditional_condition): + """ + The noise predicition model function that is used for DPM-Solver. + """ + if t_continuous.reshape((-1,)).shape[0] == 1: + t_continuous = t_continuous.expand((x.shape[0])) + if guidance_type == "uncond": + return noise_pred_fn(x, t_continuous) + elif guidance_type == "classifier": + assert classifier_fn is not None + t_input = get_model_input_time(t_continuous) + cond_grad = cond_grad_fn(x, t_input, condition) + sigma_t = noise_schedule.marginal_std(t_continuous) + noise = noise_pred_fn(x, t_continuous) + return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad + elif guidance_type == "classifier-free": + if guidance_scale == 1. or unconditional_condition is None: + return noise_pred_fn(x, t_continuous, cond=condition) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t_continuous] * 2) + if isinstance(condition, dict): + assert isinstance(unconditional_condition, dict) + c_in = dict() + for k in condition: + if isinstance(condition[k], list): + c_in[k] = [torch.cat([ + unconditional_condition[k][i], + condition[k][i]]) for i in range(len(condition[k]))] + else: + c_in[k] = torch.cat([ + unconditional_condition[k], + condition[k]]) + elif isinstance(condition, list): + c_in = list() + assert isinstance(unconditional_condition, list) + for i in range(len(condition)): + c_in.append(torch.cat([unconditional_condition[i], condition[i]])) + else: + c_in = torch.cat([unconditional_condition, condition]) + noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2) + return noise_uncond + guidance_scale * (noise - noise_uncond) + + assert model_type in ["noise", "x_start", "v"] + assert guidance_type in ["uncond", "classifier", "classifier-free"] + return model_fn + + +class UniPC: + def __init__( + self, + model_fn, + noise_schedule, + predict_x0=True, + thresholding=False, + max_val=1., + variant='bh1', + condition=None, + unconditional_condition=None, + before_sample=None, + after_sample=None + ): + """Construct a UniPC. + + We support both data_prediction and noise_prediction. + """ + self.model_fn_ = model_fn + self.noise_schedule = noise_schedule + self.variant = variant + self.predict_x0 = predict_x0 + self.thresholding = thresholding + self.max_val = max_val + self.condition = condition + self.unconditional_condition = unconditional_condition + self.before_sample = before_sample + self.after_sample = after_sample + + def dynamic_thresholding_fn(self, x0, t=None): + """ + The dynamic thresholding method. + """ + dims = x0.dim() + p = self.dynamic_thresholding_ratio + s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) + s = expand_dims(torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), dims) + x0 = torch.clamp(x0, -s, s) / s + return x0 + + def model(self, x, t): + cond = self.condition + uncond = self.unconditional_condition + if self.before_sample is not None: + x, t, cond, uncond = self.before_sample(x, t, cond, uncond) + res = self.model_fn_(x, t, cond, uncond) + if self.after_sample is not None: + x, t, cond, uncond, res = self.after_sample(x, t, cond, uncond, res) + + if isinstance(res, tuple): + # (None, pred_x0) + res = res[1] + + return res + + def noise_prediction_fn(self, x, t): + """ + Return the noise prediction model. + """ + return self.model(x, t) + + def data_prediction_fn(self, x, t): + """ + Return the data prediction model (with thresholding). + """ + noise = self.noise_prediction_fn(x, t) + dims = x.dim() + alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t) + from pprint import pp + print("X:") + pp(x) + print("sigma_t:") + pp(sigma_t) + print("noise:") + pp(noise) + print("alpha_t:") + pp(alpha_t) + x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims) + if self.thresholding: + p = 0.995 # A hyperparameter in the paper of "Imagen" [1]. + s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) + s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims) + x0 = torch.clamp(x0, -s, s) / s + return x0 + + def model_fn(self, x, t): + """ + Convert the model to the noise prediction model or the data prediction model. + """ + if self.predict_x0: + return self.data_prediction_fn(x, t) + else: + return self.noise_prediction_fn(x, t) + + def get_time_steps(self, skip_type, t_T, t_0, N, device): + """Compute the intermediate time steps for sampling. + """ + if skip_type == 'logSNR': + lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device)) + lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device)) + logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device) + return self.noise_schedule.inverse_lambda(logSNR_steps) + elif skip_type == 'time_uniform': + return torch.linspace(t_T, t_0, N + 1).to(device) + elif skip_type == 'time_quadratic': + t_order = 2 + t = torch.linspace(t_T**(1. / t_order), t_0**(1. / t_order), N + 1).pow(t_order).to(device) + return t + else: + raise ValueError("Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type)) + + def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device): + """ + Get the order of each step for sampling by the singlestep DPM-Solver. + """ + if order == 3: + K = steps // 3 + 1 + if steps % 3 == 0: + orders = [3,] * (K - 2) + [2, 1] + elif steps % 3 == 1: + orders = [3,] * (K - 1) + [1] + else: + orders = [3,] * (K - 1) + [2] + elif order == 2: + if steps % 2 == 0: + K = steps // 2 + orders = [2,] * K + else: + K = steps // 2 + 1 + orders = [2,] * (K - 1) + [1] + elif order == 1: + K = steps + orders = [1,] * steps + else: + raise ValueError("'order' must be '1' or '2' or '3'.") + if skip_type == 'logSNR': + # To reproduce the results in DPM-Solver paper + timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device) + else: + timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[torch.cumsum(torch.tensor([0,] + orders), 0).to(device)] + return timesteps_outer, orders + + def denoise_to_zero_fn(self, x, s): + """ + Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization. + """ + return self.data_prediction_fn(x, s) + + def multistep_uni_pc_update(self, x, model_prev_list, t_prev_list, t, order, **kwargs): + if len(t.shape) == 0: + t = t.view(-1) + if 'bh' in self.variant: + return self.multistep_uni_pc_bh_update(x, model_prev_list, t_prev_list, t, order, **kwargs) + else: + assert self.variant == 'vary_coeff' + return self.multistep_uni_pc_vary_update(x, model_prev_list, t_prev_list, t, order, **kwargs) + + def multistep_uni_pc_vary_update(self, x, model_prev_list, t_prev_list, t, order, use_corrector=True): + print(f'using unified predictor-corrector with order {order} (solver type: vary coeff)') + ns = self.noise_schedule + assert order <= len(model_prev_list) + + # first compute rks + t_prev_0 = t_prev_list[-1] + lambda_prev_0 = ns.marginal_lambda(t_prev_0) + lambda_t = ns.marginal_lambda(t) + model_prev_0 = model_prev_list[-1] + sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) + log_alpha_t = ns.marginal_log_mean_coeff(t) + alpha_t = torch.exp(log_alpha_t) + + h = lambda_t - lambda_prev_0 + + rks = [] + D1s = [] + for i in range(1, order): + t_prev_i = t_prev_list[-(i + 1)] + model_prev_i = model_prev_list[-(i + 1)] + lambda_prev_i = ns.marginal_lambda(t_prev_i) + rk = (lambda_prev_i - lambda_prev_0) / h + rks.append(rk) + D1s.append((model_prev_i - model_prev_0) / rk) + + rks.append(1.) + rks = torch.tensor(rks, device=x.device) + + K = len(rks) + # build C matrix + C = [] + + col = torch.ones_like(rks) + for k in range(1, K + 1): + C.append(col) + col = col * rks / (k + 1) + C = torch.stack(C, dim=1) + + if len(D1s) > 0: + D1s = torch.stack(D1s, dim=1) # (B, K) + C_inv_p = torch.linalg.inv(C[:-1, :-1]) + A_p = C_inv_p + + if use_corrector: + print('using corrector') + C_inv = torch.linalg.inv(C) + A_c = C_inv + + hh = -h if self.predict_x0 else h + h_phi_1 = torch.expm1(hh) + h_phi_ks = [] + factorial_k = 1 + h_phi_k = h_phi_1 + for k in range(1, K + 2): + h_phi_ks.append(h_phi_k) + h_phi_k = h_phi_k / hh - 1 / factorial_k + factorial_k *= (k + 1) + + model_t = None + if self.predict_x0: + x_t_ = ( + sigma_t / sigma_prev_0 * x + - alpha_t * h_phi_1 * model_prev_0 + ) + # now predictor + x_t = x_t_ + if len(D1s) > 0: + # compute the residuals for predictor + for k in range(K - 1): + x_t = x_t - alpha_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_p[k]) + # now corrector + if use_corrector: + model_t = self.model_fn(x_t, t) + D1_t = (model_t - model_prev_0) + x_t = x_t_ + k = 0 + for k in range(K - 1): + x_t = x_t - alpha_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_c[k][:-1]) + x_t = x_t - alpha_t * h_phi_ks[K] * (D1_t * A_c[k][-1]) + else: + log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) + x_t_ = ( + (torch.exp(log_alpha_t - log_alpha_prev_0)) * x + - (sigma_t * h_phi_1) * model_prev_0 + ) + # now predictor + x_t = x_t_ + if len(D1s) > 0: + # compute the residuals for predictor + for k in range(K - 1): + x_t = x_t - sigma_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_p[k]) + # now corrector + if use_corrector: + model_t = self.model_fn(x_t, t) + D1_t = (model_t - model_prev_0) + x_t = x_t_ + k = 0 + for k in range(K - 1): + x_t = x_t - sigma_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_c[k][:-1]) + x_t = x_t - sigma_t * h_phi_ks[K] * (D1_t * A_c[k][-1]) + return x_t, model_t + + def multistep_uni_pc_bh_update(self, x, model_prev_list, t_prev_list, t, order, x_t=None, use_corrector=True): + print(f'using unified predictor-corrector with order {order} (solver type: B(h))') + ns = self.noise_schedule + assert order <= len(model_prev_list) + dims = x.dim() + + # first compute rks + t_prev_0 = t_prev_list[-1] + lambda_prev_0 = ns.marginal_lambda(t_prev_0) + lambda_t = ns.marginal_lambda(t) + model_prev_0 = model_prev_list[-1] + sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) + log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) + alpha_t = torch.exp(log_alpha_t) + + h = lambda_t - lambda_prev_0 + + rks = [] + D1s = [] + for i in range(1, order): + t_prev_i = t_prev_list[-(i + 1)] + model_prev_i = model_prev_list[-(i + 1)] + lambda_prev_i = ns.marginal_lambda(t_prev_i) + rk = ((lambda_prev_i - lambda_prev_0) / h)[0] + rks.append(rk) + D1s.append((model_prev_i - model_prev_0) / rk) + + rks.append(1.) + rks = torch.tensor(rks, device=x.device) + + R = [] + b = [] + + hh = -h[0] if self.predict_x0 else h[0] + h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1 + h_phi_k = h_phi_1 / hh - 1 + + factorial_i = 1 + + if self.variant == 'bh1': + B_h = hh + elif self.variant == 'bh2': + B_h = torch.expm1(hh) + else: + raise NotImplementedError() + + for i in range(1, order + 1): + R.append(torch.pow(rks, i - 1)) + b.append(h_phi_k * factorial_i / B_h) + factorial_i *= (i + 1) + h_phi_k = h_phi_k / hh - 1 / factorial_i + + R = torch.stack(R) + b = torch.tensor(b, device=x.device) + + # now predictor + use_predictor = len(D1s) > 0 and x_t is None + if len(D1s) > 0: + D1s = torch.stack(D1s, dim=1) # (B, K) + if x_t is None: + # for order 2, we use a simplified version + if order == 2: + rhos_p = torch.tensor([0.5], device=b.device) + else: + rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]) + else: + D1s = None + + if use_corrector: + print('using corrector') + # for order 1, we use a simplified version + if order == 1: + rhos_c = torch.tensor([0.5], device=b.device) + else: + rhos_c = torch.linalg.solve(R, b) + + model_t = None + if self.predict_x0: + x_t_ = ( + expand_dims(sigma_t / sigma_prev_0, dims) * x + - expand_dims(alpha_t * h_phi_1, dims)* model_prev_0 + ) + + if x_t is None: + if use_predictor: + pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s) + else: + pred_res = 0 + x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * pred_res + + if use_corrector: + model_t = self.model_fn(x_t, t) + if D1s is not None: + corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s) + else: + corr_res = 0 + D1_t = (model_t - model_prev_0) + x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * (corr_res + rhos_c[-1] * D1_t) + else: + x_t_ = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dimss) * x + - expand_dims(sigma_t * h_phi_1, dims) * model_prev_0 + ) + if x_t is None: + if use_predictor: + pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s) + else: + pred_res = 0 + x_t = x_t_ - expand_dims(sigma_t * B_h, dims) * pred_res + + if use_corrector: + model_t = self.model_fn(x_t, t) + if D1s is not None: + corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s) + else: + corr_res = 0 + D1_t = (model_t - model_prev_0) + x_t = x_t_ - expand_dims(sigma_t * B_h, dims) * (corr_res + rhos_c[-1] * D1_t) + return x_t, model_t + + + def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform', + method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver', + atol=0.0078, rtol=0.05, corrector=False, + ): + t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end + t_T = self.noise_schedule.T if t_start is None else t_start + device = x.device + if method == 'multistep': + assert steps >= order + timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device) + assert timesteps.shape[0] - 1 == steps + with torch.no_grad(): + vec_t = timesteps[0].expand((x.shape[0])) + model_prev_list = [self.model_fn(x, vec_t)] + t_prev_list = [vec_t] + # Init the first `order` values by lower order multistep DPM-Solver. + for init_order in range(1, order): + vec_t = timesteps[init_order].expand(x.shape[0]) + x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, init_order, use_corrector=True) + if model_x is None: + model_x = self.model_fn(x, vec_t) + model_prev_list.append(model_x) + t_prev_list.append(vec_t) + for step in range(order, steps + 1): + vec_t = timesteps[step].expand(x.shape[0]) + if lower_order_final: + step_order = min(order, steps + 1 - step) + else: + step_order = order + print('this step order:', step_order) + if step == steps: + print('do not run corrector at the last step') + use_corrector = False + else: + use_corrector = True + x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, step_order, use_corrector=use_corrector) + for i in range(order - 1): + t_prev_list[i] = t_prev_list[i + 1] + model_prev_list[i] = model_prev_list[i + 1] + t_prev_list[-1] = vec_t + # We do not need to evaluate the final model value. + if step < steps: + if model_x is None: + model_x = self.model_fn(x, vec_t) + model_prev_list[-1] = model_x + else: + raise NotImplementedError() + if denoise_to_zero: + x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0) + return x + + +############################################################# +# other utility functions +############################################################# + +def interpolate_fn(x, xp, yp): + """ + A piecewise linear function y = f(x), using xp and yp as keypoints. + We implement f(x) in a differentiable way (i.e. applicable for autograd). + The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.) + + Args: + x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver). + xp: PyTorch tensor with shape [C, K], where K is the number of keypoints. + yp: PyTorch tensor with shape [C, K]. + Returns: + The function values f(x), with shape [N, C]. + """ + N, K = x.shape[0], xp.shape[1] + all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2) + sorted_all_x, x_indices = torch.sort(all_x, dim=2) + x_idx = torch.argmin(x_indices, dim=2) + cand_start_idx = x_idx - 1 + start_idx = torch.where( + torch.eq(x_idx, 0), + torch.tensor(1, device=x.device), + torch.where( + torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx, + ), + ) + end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1) + start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2) + end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2) + start_idx2 = torch.where( + torch.eq(x_idx, 0), + torch.tensor(0, device=x.device), + torch.where( + torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx, + ), + ) + y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1) + start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2) + end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2) + cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x) + return cand + + +def expand_dims(v, dims): + """ + Expand the tensor `v` to the dim `dims`. + + Args: + `v`: a PyTorch tensor with shape [N]. + `dim`: a `int`. + Returns: + a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`. + """ + return v[(...,) + (None,)*(dims - 1)] diff --git a/modules/processing.py b/modules/processing.py index e1b53ac0a..11e726dfe 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -884,7 +884,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): shared.state.nextjob() - img2img_sampler_name = self.sampler_name if self.sampler_name != 'PLMS' else 'DDIM' # PLMS does not support img2img so we just silently switch ot DDIM + img2img_sampler_name = 'DDIM' # PLMS does not support img2img so we just silently switch ot DDIM self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model) samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2] diff --git a/modules/sd_samplers_compvis.py b/modules/sd_samplers_compvis.py index d03131cd4..86fa1c5be 100644 --- a/modules/sd_samplers_compvis.py +++ b/modules/sd_samplers_compvis.py @@ -7,19 +7,27 @@ import torch from modules.shared import state from modules import sd_samplers_common, prompt_parser, shared +import modules.models.diffusion.uni_pc samplers_data_compvis = [ sd_samplers_common.SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}), sd_samplers_common.SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}), + sd_samplers_common.SamplerData('UniPC', lambda model: VanillaStableDiffusionSampler(modules.models.diffusion.uni_pc.UniPCSampler, model), [], {}), ] class VanillaStableDiffusionSampler: def __init__(self, constructor, sd_model): self.sampler = constructor(sd_model) + self.is_ddim = hasattr(self.sampler, 'p_sample_ddim') self.is_plms = hasattr(self.sampler, 'p_sample_plms') - self.orig_p_sample_ddim = self.sampler.p_sample_plms if self.is_plms else self.sampler.p_sample_ddim + self.is_unipc = isinstance(self.sampler, modules.models.diffusion.uni_pc.UniPCSampler) + self.orig_p_sample_ddim = None + if self.is_plms: + self.orig_p_sample_ddim = self.sampler.p_sample_plms + elif self.is_ddim: + self.orig_p_sample_ddim = self.sampler.p_sample_ddim self.mask = None self.nmask = None self.init_latent = None @@ -45,6 +53,15 @@ class VanillaStableDiffusionSampler: return self.last_latent def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs): + x_dec, ts, cond, unconditional_conditioning = self.before_sample(x_dec, ts, cond, unconditional_conditioning) + + res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs) + + x_dec, ts, cond, unconditional_conditioning, res = self.after_sample(x_dec, ts, cond, unconditional_conditioning, res) + + return res + + def before_sample(self, x, ts, cond, unconditional_conditioning): if state.interrupted or state.skipped: raise sd_samplers_common.InterruptedException @@ -76,7 +93,7 @@ class VanillaStableDiffusionSampler: if self.mask is not None: img_orig = self.sampler.model.q_sample(self.init_latent, ts) - x_dec = img_orig * self.mask + self.nmask * x_dec + x = img_orig * self.mask + self.nmask * x # Wrap the image conditioning back up since the DDIM code can accept the dict directly. # Note that they need to be lists because it just concatenates them later. @@ -84,7 +101,13 @@ class VanillaStableDiffusionSampler: cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]} unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]} - res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs) + return x, ts, cond, unconditional_conditioning + + def after_sample(self, x, ts, cond, uncond, res): + if self.is_unipc: + # unipc model_fn returns (pred_x0) + # p_sample_ddim returns (x_prev, pred_x0) + res = (None, res[0]) if self.mask is not None: self.last_latent = self.init_latent * self.mask + self.nmask * res[1] @@ -97,7 +120,7 @@ class VanillaStableDiffusionSampler: state.sampling_step = self.step shared.total_tqdm.update() - return res + return x, ts, cond, uncond, res def initialize(self, p): self.eta = p.eta if p.eta is not None else shared.opts.eta_ddim @@ -107,12 +130,14 @@ class VanillaStableDiffusionSampler: for fieldname in ['p_sample_ddim', 'p_sample_plms']: if hasattr(self.sampler, fieldname): setattr(self.sampler, fieldname, self.p_sample_ddim_hook) + if self.is_unipc: + self.sampler.set_hooks(lambda x, t, c, u: self.before_sample(x, t, c, u), lambda x, t, c, u, r: self.after_sample(x, t, c, u, r)) self.mask = p.mask if hasattr(p, 'mask') else None self.nmask = p.nmask if hasattr(p, 'nmask') else None def adjust_steps_if_invalid(self, p, num_steps): - if (self.config.name == 'DDIM' and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'): + if ((self.config.name == 'DDIM' or self.config.name == "UniPC") and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'): valid_step = 999 / (1000 // num_steps) if valid_step == math.floor(valid_step): return int(valid_step) + 1 diff --git a/test/basic_features/txt2img_test.py b/test/basic_features/txt2img_test.py index 5aa43a44a..cb525fbb7 100644 --- a/test/basic_features/txt2img_test.py +++ b/test/basic_features/txt2img_test.py @@ -66,6 +66,8 @@ class TestTxt2ImgWorking(unittest.TestCase): self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) self.simple_txt2img["sampler_index"] = "DDIM" self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) + self.simple_txt2img["sampler_index"] = "UniPC" + self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) def test_txt2img_multiple_batches_performed(self): self.simple_txt2img["n_iter"] = 2 From 21880eb9e57b884635a07d2360831b4186afddf4 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Fri, 10 Feb 2023 04:47:08 -0800 Subject: [PATCH 002/104] Fix logspam and live previews --- modules/models/diffusion/uni_pc/sampler.py | 20 ++++++++++---- modules/models/diffusion/uni_pc/uni_pc.py | 32 ++++++++++------------ modules/sd_samplers_compvis.py | 20 ++++++++------ 3 files changed, 41 insertions(+), 31 deletions(-) diff --git a/modules/models/diffusion/uni_pc/sampler.py b/modules/models/diffusion/uni_pc/sampler.py index 7cccd8a24..219e9862c 100644 --- a/modules/models/diffusion/uni_pc/sampler.py +++ b/modules/models/diffusion/uni_pc/sampler.py @@ -19,9 +19,10 @@ class UniPCSampler(object): attr = attr.to(torch.device("cuda")) setattr(self, name, attr) - def set_hooks(self, before, after): - self.before_sample = before - self.after_sample = after + def set_hooks(self, before_sample, after_sample, after_update): + self.before_sample = before_sample + self.after_sample = after_sample + self.after_update = after_update @torch.no_grad() def sample(self, @@ -50,9 +51,17 @@ class UniPCSampler(object): ): if conditioning is not None: if isinstance(conditioning, dict): - cbs = conditioning[list(conditioning.keys())[0]].shape[0] + ctmp = conditioning[list(conditioning.keys())[0]] + while isinstance(ctmp, list): ctmp = ctmp[0] + cbs = ctmp.shape[0] if cbs != batch_size: print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + + elif isinstance(conditioning, list): + for ctmp in conditioning: + if ctmp.shape[0] != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + else: if conditioning.shape[0] != batch_size: print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") @@ -60,6 +69,7 @@ class UniPCSampler(object): # sampling C, H, W = shape size = (batch_size, C, H, W) + print(f'Data shape for UniPC sampling is {size}, eta {eta}') device = self.model.betas.device if x_T is None: @@ -79,7 +89,7 @@ class UniPCSampler(object): guidance_scale=unconditional_guidance_scale, ) - uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, condition=conditioning, unconditional_condition=unconditional_conditioning, before_sample=self.before_sample, after_sample=self.after_sample) + uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, condition=conditioning, unconditional_condition=unconditional_conditioning, before_sample=self.before_sample, after_sample=self.after_sample, after_update=self.after_update) x = uni_pc.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=3, lower_order_final=True) return x.to(device), None diff --git a/modules/models/diffusion/uni_pc/uni_pc.py b/modules/models/diffusion/uni_pc/uni_pc.py index ec6b37da7..31ee81a65 100644 --- a/modules/models/diffusion/uni_pc/uni_pc.py +++ b/modules/models/diffusion/uni_pc/uni_pc.py @@ -378,7 +378,8 @@ class UniPC: condition=None, unconditional_condition=None, before_sample=None, - after_sample=None + after_sample=None, + after_update=None ): """Construct a UniPC. @@ -394,6 +395,7 @@ class UniPC: self.unconditional_condition = unconditional_condition self.before_sample = before_sample self.after_sample = after_sample + self.after_update = after_update def dynamic_thresholding_fn(self, x0, t=None): """ @@ -434,15 +436,6 @@ class UniPC: noise = self.noise_prediction_fn(x, t) dims = x.dim() alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t) - from pprint import pp - print("X:") - pp(x) - print("sigma_t:") - pp(sigma_t) - print("noise:") - pp(noise) - print("alpha_t:") - pp(alpha_t) x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims) if self.thresholding: p = 0.995 # A hyperparameter in the paper of "Imagen" [1]. @@ -524,7 +517,7 @@ class UniPC: return self.multistep_uni_pc_vary_update(x, model_prev_list, t_prev_list, t, order, **kwargs) def multistep_uni_pc_vary_update(self, x, model_prev_list, t_prev_list, t, order, use_corrector=True): - print(f'using unified predictor-corrector with order {order} (solver type: vary coeff)') + #print(f'using unified predictor-corrector with order {order} (solver type: vary coeff)') ns = self.noise_schedule assert order <= len(model_prev_list) @@ -568,7 +561,7 @@ class UniPC: A_p = C_inv_p if use_corrector: - print('using corrector') + #print('using corrector') C_inv = torch.linalg.inv(C) A_c = C_inv @@ -627,7 +620,7 @@ class UniPC: return x_t, model_t def multistep_uni_pc_bh_update(self, x, model_prev_list, t_prev_list, t, order, x_t=None, use_corrector=True): - print(f'using unified predictor-corrector with order {order} (solver type: B(h))') + #print(f'using unified predictor-corrector with order {order} (solver type: B(h))') ns = self.noise_schedule assert order <= len(model_prev_list) dims = x.dim() @@ -695,7 +688,7 @@ class UniPC: D1s = None if use_corrector: - print('using corrector') + #print('using corrector') # for order 1, we use a simplified version if order == 1: rhos_c = torch.tensor([0.5], device=b.device) @@ -755,8 +748,9 @@ class UniPC: t_T = self.noise_schedule.T if t_start is None else t_start device = x.device if method == 'multistep': - assert steps >= order + assert steps >= order, "UniPC order must be < sampling steps" timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device) + print(f"Running UniPC Sampling with {timesteps.shape[0]} timesteps") assert timesteps.shape[0] - 1 == steps with torch.no_grad(): vec_t = timesteps[0].expand((x.shape[0])) @@ -768,6 +762,8 @@ class UniPC: x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, init_order, use_corrector=True) if model_x is None: model_x = self.model_fn(x, vec_t) + if self.after_update is not None: + self.after_update(x, model_x) model_prev_list.append(model_x) t_prev_list.append(vec_t) for step in range(order, steps + 1): @@ -776,13 +772,15 @@ class UniPC: step_order = min(order, steps + 1 - step) else: step_order = order - print('this step order:', step_order) + #print('this step order:', step_order) if step == steps: - print('do not run corrector at the last step') + #print('do not run corrector at the last step') use_corrector = False else: use_corrector = True x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, step_order, use_corrector=use_corrector) + if self.after_update is not None: + self.after_update(x, model_x) for i in range(order - 1): t_prev_list[i] = t_prev_list[i + 1] model_prev_list[i] = model_prev_list[i + 1] diff --git a/modules/sd_samplers_compvis.py b/modules/sd_samplers_compvis.py index 86fa1c5be..946079ae5 100644 --- a/modules/sd_samplers_compvis.py +++ b/modules/sd_samplers_compvis.py @@ -103,16 +103,11 @@ class VanillaStableDiffusionSampler: return x, ts, cond, unconditional_conditioning - def after_sample(self, x, ts, cond, uncond, res): - if self.is_unipc: - # unipc model_fn returns (pred_x0) - # p_sample_ddim returns (x_prev, pred_x0) - res = (None, res[0]) - + def update_step(self, last_latent): if self.mask is not None: - self.last_latent = self.init_latent * self.mask + self.nmask * res[1] + self.last_latent = self.init_latent * self.mask + self.nmask * last_latent else: - self.last_latent = res[1] + self.last_latent = last_latent sd_samplers_common.store_latent(self.last_latent) @@ -120,8 +115,15 @@ class VanillaStableDiffusionSampler: state.sampling_step = self.step shared.total_tqdm.update() + def after_sample(self, x, ts, cond, uncond, res): + if not self.is_unipc: + self.update_step(res[1]) + return x, ts, cond, uncond, res + def unipc_after_update(self, x, model_x): + self.update_step(x) + def initialize(self, p): self.eta = p.eta if p.eta is not None else shared.opts.eta_ddim if self.eta != 0.0: @@ -131,7 +133,7 @@ class VanillaStableDiffusionSampler: if hasattr(self.sampler, fieldname): setattr(self.sampler, fieldname, self.p_sample_ddim_hook) if self.is_unipc: - self.sampler.set_hooks(lambda x, t, c, u: self.before_sample(x, t, c, u), lambda x, t, c, u, r: self.after_sample(x, t, c, u, r)) + self.sampler.set_hooks(lambda x, t, c, u: self.before_sample(x, t, c, u), lambda x, t, c, u, r: self.after_sample(x, t, c, u, r), lambda x, mx: self.unipc_after_update(x, mx)) self.mask = p.mask if hasattr(p, 'mask') else None self.nmask = p.nmask if hasattr(p, 'nmask') else None From c88dcc20d495dab4be2692bdff30277112dbe416 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Fri, 10 Feb 2023 05:00:09 -0800 Subject: [PATCH 003/104] UniPC does not support img2img (for now) --- modules/processing.py | 2 +- modules/sd_samplers.py | 2 +- modules/sd_samplers_compvis.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index 11e726dfe..b7cf53570 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -884,7 +884,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): shared.state.nextjob() - img2img_sampler_name = 'DDIM' # PLMS does not support img2img so we just silently switch ot DDIM + img2img_sampler_name = 'DDIM' # PLMS/UniPC does not support img2img so we just silently switch ot DDIM self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model) samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2] diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 28c2136fe..ff361f22b 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -32,7 +32,7 @@ def set_samplers(): global samplers, samplers_for_img2img hidden = set(shared.opts.hide_samplers) - hidden_img2img = set(shared.opts.hide_samplers + ['PLMS']) + hidden_img2img = set(shared.opts.hide_samplers + ['PLMS', 'UniPC']) samplers = [x for x in all_samplers if x.name not in hidden] samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img] diff --git a/modules/sd_samplers_compvis.py b/modules/sd_samplers_compvis.py index 946079ae5..ad39ab2b3 100644 --- a/modules/sd_samplers_compvis.py +++ b/modules/sd_samplers_compvis.py @@ -139,7 +139,7 @@ class VanillaStableDiffusionSampler: self.nmask = p.nmask if hasattr(p, 'nmask') else None def adjust_steps_if_invalid(self, p, num_steps): - if ((self.config.name == 'DDIM' or self.config.name == "UniPC") and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'): + if ((self.config.name == 'DDIM') and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS') or (self.config.name == 'UniPC'): valid_step = 999 / (1000 // num_steps) if valid_step == math.floor(valid_step): return int(valid_step) + 1 From 79ffb9453f8eddbdd4e316b9d9c75812b0eea4e1 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Fri, 10 Feb 2023 05:27:05 -0800 Subject: [PATCH 004/104] Add UniPC sampler settings --- modules/models/diffusion/uni_pc/sampler.py | 5 +++-- modules/models/diffusion/uni_pc/uni_pc.py | 2 +- modules/shared.py | 5 +++++ scripts/xyz_grid.py | 7 +++++++ 4 files changed, 16 insertions(+), 3 deletions(-) diff --git a/modules/models/diffusion/uni_pc/sampler.py b/modules/models/diffusion/uni_pc/sampler.py index 219e9862c..e66a21e3b 100644 --- a/modules/models/diffusion/uni_pc/sampler.py +++ b/modules/models/diffusion/uni_pc/sampler.py @@ -3,6 +3,7 @@ import torch from .uni_pc import NoiseScheduleVP, model_wrapper, UniPC +from modules import shared class UniPCSampler(object): def __init__(self, model, **kwargs): @@ -89,7 +90,7 @@ class UniPCSampler(object): guidance_scale=unconditional_guidance_scale, ) - uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, condition=conditioning, unconditional_condition=unconditional_conditioning, before_sample=self.before_sample, after_sample=self.after_sample, after_update=self.after_update) - x = uni_pc.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=3, lower_order_final=True) + uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=shared.opts.uni_pc_thresholding, variant=shared.opts.uni_pc_variant, condition=conditioning, unconditional_condition=unconditional_conditioning, before_sample=self.before_sample, after_sample=self.after_sample, after_update=self.after_update) + x = uni_pc.sample(img, steps=S, skip_type=shared.opts.uni_pc_skip_type, method="multistep", order=shared.opts.uni_pc_order, lower_order_final=shared.opts.uni_pc_lower_order_final) return x.to(device), None diff --git a/modules/models/diffusion/uni_pc/uni_pc.py b/modules/models/diffusion/uni_pc/uni_pc.py index 31ee81a65..df63d1bcf 100644 --- a/modules/models/diffusion/uni_pc/uni_pc.py +++ b/modules/models/diffusion/uni_pc/uni_pc.py @@ -750,7 +750,7 @@ class UniPC: if method == 'multistep': assert steps >= order, "UniPC order must be < sampling steps" timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device) - print(f"Running UniPC Sampling with {timesteps.shape[0]} timesteps") + print(f"Running UniPC Sampling with {timesteps.shape[0]} timesteps, order {order}") assert timesteps.shape[0] - 1 == steps with torch.no_grad(): vec_t = timesteps[0].expand((x.shape[0])) diff --git a/modules/shared.py b/modules/shared.py index 79fbf7249..342420739 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -480,6 +480,11 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters" 's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), 'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}), 'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma"), + 'uni_pc_variant': OptionInfo("bh1", "UniPC variant", gr.Radio, {"choices": ["bh1", "vary_coeff"]}), + 'uni_pc_skip_type': OptionInfo("time_uniform", "UniPC skip type", gr.Radio, {"choices": ["time_uniform", "time_quadratic", "logSNR"]}), + 'uni_pc_order': OptionInfo(3, "UniPC order (must be < sampling steps)", gr.Slider, {"minimum": 1, "maximum": 150 - 1, "step": 1}), + 'uni_pc_thresholding': OptionInfo(False, "UniPC thresholding"), + 'uni_pc_lower_order_final': OptionInfo(True, "UniPC lower order final"), })) options_templates.update(options_section(('postprocessing', "Postprocessing"), { diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py index 5982cfbaa..72421e0ca 100644 --- a/scripts/xyz_grid.py +++ b/scripts/xyz_grid.py @@ -126,6 +126,10 @@ def apply_styles(p: StableDiffusionProcessingTxt2Img, x: str, _): p.styles.extend(x.split(',')) +def apply_uni_pc_order(p, x, xs): + opts.data["uni_pc_order"] = min(x, p.steps - 1) + + def format_value_add_label(p, opt, x): if type(x) == float: x = round(x, 8) @@ -202,6 +206,7 @@ axis_options = [ AxisOptionImg2Img("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight")), AxisOption("VAE", str, apply_vae, cost=0.7, choices=lambda: list(sd_vae.vae_dict)), AxisOption("Styles", str, apply_styles, choices=lambda: list(shared.prompt_styles.styles)), + AxisOption("UniPC Order", int, apply_uni_pc_order, cost=0.5), ] @@ -310,9 +315,11 @@ class SharedSettingsStackHelper(object): def __enter__(self): self.CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers self.vae = opts.sd_vae + self.uni_pc_order = opts.uni_pc_order def __exit__(self, exc_type, exc_value, tb): opts.data["sd_vae"] = self.vae + opts.data["uni_pc_order"] = self.uni_pc_order modules.sd_models.reload_model_weights() modules.sd_vae.reload_vae_weights() From 06cb0dc92095647e4856be10b4d7dc12f5e11fa1 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Fri, 10 Feb 2023 05:36:41 -0800 Subject: [PATCH 005/104] Fix UniPC order --- modules/shared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/shared.py b/modules/shared.py index 342420739..670d4954e 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -482,7 +482,7 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters" 'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma"), 'uni_pc_variant': OptionInfo("bh1", "UniPC variant", gr.Radio, {"choices": ["bh1", "vary_coeff"]}), 'uni_pc_skip_type': OptionInfo("time_uniform", "UniPC skip type", gr.Radio, {"choices": ["time_uniform", "time_quadratic", "logSNR"]}), - 'uni_pc_order': OptionInfo(3, "UniPC order (must be < sampling steps)", gr.Slider, {"minimum": 1, "maximum": 150 - 1, "step": 1}), + 'uni_pc_order': OptionInfo(3, "UniPC order (must be < sampling steps)", gr.Slider, {"minimum": 1, "maximum": 50, "step": 1}), 'uni_pc_thresholding': OptionInfo(False, "UniPC thresholding"), 'uni_pc_lower_order_final': OptionInfo(True, "UniPC lower order final"), })) From fb274229b2c5c1a89dac0b3da28c08c92d71fd95 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Fri, 10 Feb 2023 14:30:35 -0800 Subject: [PATCH 006/104] bug fix --- modules/models/diffusion/uni_pc/sampler.py | 2 +- modules/processing.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/modules/models/diffusion/uni_pc/sampler.py b/modules/models/diffusion/uni_pc/sampler.py index e66a21e3b..0bef6eede 100644 --- a/modules/models/diffusion/uni_pc/sampler.py +++ b/modules/models/diffusion/uni_pc/sampler.py @@ -70,7 +70,7 @@ class UniPCSampler(object): # sampling C, H, W = shape size = (batch_size, C, H, W) - print(f'Data shape for UniPC sampling is {size}, eta {eta}') + print(f'Data shape for UniPC sampling is {size}') device = self.model.betas.device if x_T is None: diff --git a/modules/processing.py b/modules/processing.py index b7cf53570..0ca15491b 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -884,7 +884,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): shared.state.nextjob() - img2img_sampler_name = 'DDIM' # PLMS/UniPC does not support img2img so we just silently switch ot DDIM + img2img_sampler_name = self.sampler_name + if self.sampler_name in ['PLMS', 'UniPC']: # PLMS/UniPC do not support img2img so we just silently switch to DDIM + img2img_sampler_name = 'DDIM' self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model) samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2] From 716a69237cefb385f71105dbbf50e92d664e0f42 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Sat, 11 Feb 2023 06:18:34 -0800 Subject: [PATCH 007/104] support SD2.X models --- modules/models/diffusion/uni_pc/sampler.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/modules/models/diffusion/uni_pc/sampler.py b/modules/models/diffusion/uni_pc/sampler.py index 0bef6eede..708a9b2ba 100644 --- a/modules/models/diffusion/uni_pc/sampler.py +++ b/modules/models/diffusion/uni_pc/sampler.py @@ -80,10 +80,13 @@ class UniPCSampler(object): ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod) + # SD 1.X is "noise", SD 2.X is "v" + model_type = "v" if self.model.parameterization == "v" else "noise" + model_fn = model_wrapper( lambda x, t, c: self.model.apply_model(x, t, c), ns, - model_type="noise", + model_type=model_type, guidance_type="classifier-free", #condition=conditioning, #unconditional_condition=unconditional_conditioning, From a320d157ec0221fa4e9c756327e31d881b9921ae Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Mon, 13 Feb 2023 20:26:47 -0500 Subject: [PATCH 008/104] all hiding of ui tabs --- modules/shared.py | 1 + modules/ui.py | 3 +++ 2 files changed, 4 insertions(+) diff --git a/modules/shared.py b/modules/shared.py index 79fbf7249..ded289252 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -455,6 +455,7 @@ options_templates.update(options_section(('ui', "User interface"), { "keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}), "keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing ", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}), "quicksettings": OptionInfo("sd_model_checkpoint", "Quicksettings list"), + "hidden_tabs": OptionInfo("", "Hidden UI tabs"), "ui_reorder": OptionInfo(", ".join(ui_reorder_categories), "txt2img/img2img UI item order"), "ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order"), "localization": OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)), diff --git a/modules/ui.py b/modules/ui.py index f5df1ffeb..c99e55aba 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1568,7 +1568,10 @@ def create_ui(): parameters_copypaste.connect_paste_params_buttons() with gr.Tabs(elem_id="tabs") as tabs: + hidden_tabs = [x.lower().strip() for x in shared.opts.hidden_tabs.split(",")] for interface, label, ifid in interfaces: + if label.lower() in hidden_tabs: + continue with gr.TabItem(label, id=ifid, elem_id='tab_' + ifid): interface.render() From 7df7e4d22796fda11629463f2fcbe859b98b1d19 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Tue, 14 Feb 2023 03:55:42 -0800 Subject: [PATCH 009/104] Allow extensions to declare paste fields for "Send to X" buttons --- modules/generation_parameters_copypaste.py | 5 +++-- modules/scripts.py | 9 +++++++++ modules/ui_common.py | 9 ++++++++- 3 files changed, 20 insertions(+), 3 deletions(-) diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index fc9e17aa2..93d955dbe 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -23,13 +23,14 @@ registered_param_bindings = [] class ParamBinding: - def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None): + def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None, paste_field_names=[]): self.paste_button = paste_button self.tabname = tabname self.source_text_component = source_text_component self.source_image_component = source_image_component self.source_tabname = source_tabname self.override_settings_component = override_settings_component + self.paste_field_names = paste_field_names def reset(): @@ -133,7 +134,7 @@ def connect_paste_params_buttons(): connect_paste(binding.paste_button, fields, binding.source_text_component, binding.override_settings_component, binding.tabname) if binding.source_tabname is not None and fields is not None: - paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (["Seed"] if shared.opts.send_seed else []) + paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (["Seed"] if shared.opts.send_seed else []) + binding.paste_field_names binding.paste_button.click( fn=lambda *x: x, inputs=[field for field, name in paste_fields[binding.source_tabname]["fields"] if name in paste_field_names], diff --git a/modules/scripts.py b/modules/scripts.py index 24056a12f..ac0785ce1 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -33,6 +33,11 @@ class Script: parsing infotext to set the value for the component; see ui.py's txt2img_paste_fields for an example """ + paste_field_names = None + """if set in ui(), this is a list of names of infotext fields; the fields will be sent through the + various "Send to " buttons when clicked + """ + def title(self): """this function should return the title of the script. This is what will be displayed in the dropdown menu.""" @@ -256,6 +261,7 @@ class ScriptRunner: self.alwayson_scripts = [] self.titles = [] self.infotext_fields = [] + self.paste_field_names = [] def initialize_scripts(self, is_img2img): from modules import scripts_auto_postprocessing @@ -304,6 +310,9 @@ class ScriptRunner: if script.infotext_fields is not None: self.infotext_fields += script.infotext_fields + if script.paste_field_names is not None: + self.paste_field_names += script.paste_field_names + inputs += controls inputs_alwayson += [script.alwayson for _ in controls] script.args_to = len(inputs) diff --git a/modules/ui_common.py b/modules/ui_common.py index fd047f318..a12433d23 100644 --- a/modules/ui_common.py +++ b/modules/ui_common.py @@ -198,9 +198,16 @@ Requested path was: {f} html_info = gr.HTML(elem_id=f'html_info_{tabname}') html_log = gr.HTML(elem_id=f'html_log_{tabname}') + paste_field_names = [] + if tabname == "txt2img": + paste_field_names = modules.scripts.scripts_txt2img.paste_field_names + elif tabname == "img2img": + paste_field_names = modules.scripts.scripts_img2img.paste_field_names + for paste_tabname, paste_button in buttons.items(): parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding( - paste_button=paste_button, tabname=paste_tabname, source_tabname="txt2img" if tabname == "txt2img" else None, source_image_component=result_gallery + paste_button=paste_button, tabname=paste_tabname, source_tabname="txt2img" if tabname == "txt2img" else None, source_image_component=result_gallery, + paste_field_names=paste_field_names )) return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log From 83829471decbde64d335eb510d4a5670baf68773 Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Sun, 19 Feb 2023 09:21:44 -0500 Subject: [PATCH 010/104] make ui as multiselect instead of string list --- modules/shared.py | 3 ++- modules/ui.py | 7 +++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/modules/shared.py b/modules/shared.py index 1a1abeb25..a7c5f58e1 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -305,6 +305,7 @@ def list_samplers(): hide_dirs = {"visible": not cmd_opts.hide_ui_dir_config} +tab_names = [] options_templates = {} @@ -460,7 +461,7 @@ options_templates.update(options_section(('ui', "User interface"), { "keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}), "keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing ", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}), "quicksettings": OptionInfo("sd_model_checkpoint", "Quicksettings list"), - "hidden_tabs": OptionInfo("", "Hidden UI tabs"), + "hidden_tabs": OptionInfo([], "Hidden UI tabs (requires restart)", ui_components.DropdownMulti, lambda: {"choices": [x for x in tab_names]}), "ui_reorder": OptionInfo(", ".join(ui_reorder_categories), "txt2img/img2img UI item order"), "ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order"), "localization": OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)), diff --git a/modules/ui.py b/modules/ui.py index a4ecd41b2..5ac249b2b 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1563,6 +1563,10 @@ def create_ui(): extensions_interface = ui_extensions.create_ui() interfaces += [(extensions_interface, "Extensions", "extensions")] + shared.tab_names = [] + for _interface, label, _ifid in interfaces: + shared.tab_names.append(label) + with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo: with gr.Row(elem_id="quicksettings", variant="compact"): for i, k, item in sorted(quicksettings_list, key=lambda x: quicksettings_names.get(x[1], x[0])): @@ -1572,9 +1576,8 @@ def create_ui(): parameters_copypaste.connect_paste_params_buttons() with gr.Tabs(elem_id="tabs") as tabs: - hidden_tabs = [x.lower().strip() for x in shared.opts.hidden_tabs.split(",")] for interface, label, ifid in interfaces: - if label.lower() in hidden_tabs: + if label in shared.opts.hidden_tabs: continue with gr.TabItem(label, id=ifid, elem_id='tab_' + ifid): interface.render() From bab972ff8ab6be1132ca2b58a2c4fadac0a0685d Mon Sep 17 00:00:00 2001 From: EllangoK Date: Mon, 20 Feb 2023 10:16:55 -0500 Subject: [PATCH 011/104] fixes newline being detected as its own entry --- scripts/xyz_grid.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py index 53511b121..d0ff5cb86 100644 --- a/scripts/xyz_grid.py +++ b/scripts/xyz_grid.py @@ -418,7 +418,7 @@ class Script(scripts.Script): if opt.label == 'Nothing': return [0] - valslist = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(vals)))] + valslist = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(vals))) if x] if opt.type == int: valslist_ext = [] From b0f2653541219486c8b8cbdbf8ce80caf3f62ef8 Mon Sep 17 00:00:00 2001 From: xSinStarx <98231899+xSinStarx@users.noreply.github.com> Date: Mon, 20 Feb 2023 12:39:38 -0800 Subject: [PATCH 012/104] Fixes img2img Negative Token Counter The img2img negative token counter is counting the txt2img negative prompt. --- modules/ui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/ui.py b/modules/ui.py index 0516c6436..a37b17396 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -939,7 +939,7 @@ def create_ui(): ) token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter]) - negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_negative_prompt, steps], outputs=[negative_token_counter]) + negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[img2img_negative_prompt, steps], outputs=[negative_token_counter]) ui_extra_networks.setup_ui(extra_networks_ui_img2img, img2img_gallery) From 32a4c8d961df3da4534c98fd0573d854cff1bb91 Mon Sep 17 00:00:00 2001 From: Kilvoctu Date: Mon, 20 Feb 2023 15:14:06 -0600 Subject: [PATCH 013/104] use emojis for extra network buttons MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 🔄 for refresh ❌ for close --- modules/ui_extra_networks.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 71f1d81f2..8786fde6b 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -13,6 +13,10 @@ from modules.generation_parameters_copypaste import image_from_url_text extra_pages = [] allowed_dirs = set() +# Using constants for these since the variation selector isn't visible. +# Important that they exactly match script.js for tooltip to work. +refresh_symbol = '\U0001f504' # 🔄 +close_symbol = '\U0000274C' # ❌ def register_page(page): """registers extra networks page for the UI; recommend doing it in on_before_ui() callback for extensions""" @@ -182,8 +186,8 @@ def create_ui(container, button, tabname): ui.pages.append(page_elem) filter = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", placeholder="Search...", visible=False) - button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh") - button_close = gr.Button('Close', elem_id=tabname+"_extra_close") + button_refresh = gr.Button(refresh_symbol, elem_id=tabname+"_extra_refresh") + button_close = gr.Button(close_symbol, elem_id=tabname+"_extra_close") ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False) ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False) From a2d635ad135241a0a40f67f7e1638c9c8a4ded04 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Wed, 22 Feb 2023 01:52:53 -0800 Subject: [PATCH 014/104] Add before_process_batch script callback --- modules/processing.py | 3 +++ modules/scripts.py | 23 +++++++++++++++++++++++ 2 files changed, 26 insertions(+) diff --git a/modules/processing.py b/modules/processing.py index 2009d3bf8..187e98fde 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -597,6 +597,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size] subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size] + if p.scripts is not None: + p.scripts.before_process_batch(p, batch_number=n, prompts=prompts, seeds=seeds, subseeds=subseeds) + if len(prompts) == 0: break diff --git a/modules/scripts.py b/modules/scripts.py index 24056a12f..e6a505b39 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -80,6 +80,20 @@ class Script: pass + def before_process_batch(self, p, *args, **kwargs): + """ + Called before extra networks are parsed from the prompt, so you can add + new extra network keywords to the prompt with this callback. + + **kwargs will have those items: + - batch_number - index of current batch, from 0 to number of batches-1 + - prompts - list of prompts for current batch; you can change contents of this list but changing the number of entries will likely break things + - seeds - list of seeds for current batch + - subseeds - list of subseeds for current batch + """ + + pass + def process_batch(self, p, *args, **kwargs): """ Same as process(), but called for every batch. @@ -388,6 +402,15 @@ class ScriptRunner: print(f"Error running process: {script.filename}", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) + def before_process_batch(self, p, **kwargs): + for script in self.alwayson_scripts: + try: + script_args = p.script_args[script.args_from:script.args_to] + script.before_process_batch(p, *script_args, **kwargs) + except Exception: + print(f"Error running before_process_batch: {script.filename}", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + def process_batch(self, p, **kwargs): for script in self.alwayson_scripts: try: From 2c58d373dd408153dc126f6eba1525d32fbf92bb Mon Sep 17 00:00:00 2001 From: 112292454 <92578848+112292454@users.noreply.github.com> Date: Wed, 22 Feb 2023 21:40:42 +0800 Subject: [PATCH 015/104] Update prompt_matrix.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit this file last commit fixed common situation when using both prompts matrix and high-res。 but if we just open matrix option,but not use ‘|’,we will only get one pic,and `processed.images[0].width, processed.images[1].height` will cause a index out of bounds exception --- scripts/prompt_matrix.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/prompt_matrix.py b/scripts/prompt_matrix.py index b1c486d44..7790ac38f 100644 --- a/scripts/prompt_matrix.py +++ b/scripts/prompt_matrix.py @@ -100,7 +100,7 @@ class Script(scripts.Script): processed = process_images(p) grid = images.image_grid(processed.images, p.batch_size, rows=1 << ((len(prompt_matrix_parts) - 1) // 2)) - grid = images.draw_prompt_matrix(grid, processed.images[0].width, processed.images[1].height, prompt_matrix_parts, margin_size) + grid = images.draw_prompt_matrix(grid, p.hr_upscale_to_x, p.hr_upscale_to_y,, prompt_matrix_parts, margin_size) processed.images.insert(0, grid) processed.index_of_first_image = 1 processed.infotexts.insert(0, processed.infotexts[0]) From 2fa91cbee65429e611861df1c32657c941f4acaf Mon Sep 17 00:00:00 2001 From: 112292454 <92578848+112292454@users.noreply.github.com> Date: Thu, 23 Feb 2023 01:55:07 +0800 Subject: [PATCH 016/104] Update prompt_matrix.py 1 --- scripts/prompt_matrix.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/prompt_matrix.py b/scripts/prompt_matrix.py index 7790ac38f..e9b115170 100644 --- a/scripts/prompt_matrix.py +++ b/scripts/prompt_matrix.py @@ -100,7 +100,7 @@ class Script(scripts.Script): processed = process_images(p) grid = images.image_grid(processed.images, p.batch_size, rows=1 << ((len(prompt_matrix_parts) - 1) // 2)) - grid = images.draw_prompt_matrix(grid, p.hr_upscale_to_x, p.hr_upscale_to_y,, prompt_matrix_parts, margin_size) + grid = images.draw_prompt_matrix(grid, processed.images[0].width, processed.images[0].height, prompt_matrix_parts, margin_size) processed.images.insert(0, grid) processed.index_of_first_image = 1 processed.infotexts.insert(0, processed.infotexts[0]) From 6825de7bc811d777ff0d462e5668fa4fba73a889 Mon Sep 17 00:00:00 2001 From: Thomas Young <35073576+DrakeRichards@users.noreply.github.com> Date: Wed, 22 Feb 2023 15:31:49 -0600 Subject: [PATCH 017/104] Added results selector This causes the querySelectorAll function to only select images in a results div, ignoring images that might be in an extension's gallery. --- javascript/notification.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/javascript/notification.js b/javascript/notification.js index 040a3afac..5ae6df24d 100644 --- a/javascript/notification.js +++ b/javascript/notification.js @@ -15,7 +15,7 @@ onUiUpdate(function(){ } } - const galleryPreviews = gradioApp().querySelectorAll('div[id^="tab_"][style*="display: block"] img.h-full.w-full.overflow-hidden'); + const galleryPreviews = gradioApp().querySelectorAll('div[id^="tab_"][style*="display: block"] div[id$="_results"] img.h-full.w-full.overflow-hidden'); if (galleryPreviews == null) return; From b90cad7f3136bbe04efeee2a00e95d0cc6ce1a4a Mon Sep 17 00:00:00 2001 From: "fkunn1326@users.noreply.github.com" Date: Thu, 23 Feb 2023 03:29:22 +0000 Subject: [PATCH 018/104] Add .mjs support for extensions --- modules/ui.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/modules/ui.py b/modules/ui.py index 0516c6436..2509ce2d0 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1754,6 +1754,9 @@ def reload_javascript(): for script in modules.scripts.list_scripts("javascript", ".js"): head += f'\n' + for script in modules.scripts.list_scripts("javascript", ".mjs"): + head += f'\n' + head += f'\n' def template_response(*args, **kwargs): From ac4c7f05cd38dfa99cf64f7ddb9b1656e70a13c5 Mon Sep 17 00:00:00 2001 From: Tpinion Date: Fri, 24 Feb 2023 00:42:29 +0800 Subject: [PATCH 019/104] Filter out temporary files that will be generated if the download fails. --- modules/codeformer_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py index 01fb7bd84..8d84bbc90 100644 --- a/modules/codeformer_model.py +++ b/modules/codeformer_model.py @@ -55,7 +55,7 @@ def setup_model(dirname): if self.net is not None and self.face_helper is not None: self.net.to(devices.device_codeformer) return self.net, self.face_helper - model_paths = modelloader.load_models(model_path, model_url, self.cmd_dir, download_name='codeformer-v0.1.0.pth') + model_paths = modelloader.load_models(model_path, model_url, self.cmd_dir, download_name='codeformer-v0.1.0.pth', ext_filter=['.pth']) if len(model_paths) != 0: ckpt_path = model_paths[0] else: From 327186b484c344598522a989a4b4859d6b90fb04 Mon Sep 17 00:00:00 2001 From: laksjdjf Date: Fri, 24 Feb 2023 14:03:46 +0900 Subject: [PATCH 020/104] Update script_callbacks.py --- modules/script_callbacks.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index edd0e2a72..c5bb3e71e 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -29,7 +29,7 @@ class ImageSaveParams: class CFGDenoiserParams: - def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps): + def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, tensor, uncond): self.x = x """Latent image representation in the process of being denoised""" @@ -44,6 +44,12 @@ class CFGDenoiserParams: self.total_sampling_steps = total_sampling_steps """Total number of sampling steps planned""" + + self.tensor = tensor + """ Encoder hidden states of condtioning""" + + self.uncond = uncond + """ Encoder hidden states of unconditioning""" class CFGDenoisedParams: From 9a1435946ce7cc7f2cdaa0a312e1c0d296b8c276 Mon Sep 17 00:00:00 2001 From: laksjdjf Date: Fri, 24 Feb 2023 14:04:23 +0900 Subject: [PATCH 021/104] Update sd_samplers_kdiffusion.py --- modules/sd_samplers_kdiffusion.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index 528f513fe..ea974be04 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -101,11 +101,13 @@ class CFGDenoiser(torch.nn.Module): sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma]) image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond] + [torch.zeros_like(self.init_latent)]) - denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps) + denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond) cfg_denoiser_callback(denoiser_params) x_in = denoiser_params.x image_cond_in = denoiser_params.image_cond sigma_in = denoiser_params.sigma + tensor = denoiser_params.tensor + uncond = denoiser_params.uncond if tensor.shape[1] == uncond.shape[1]: if not is_edit_model: From 534cf60afb547de72891e1e87b59a1433aadeee3 Mon Sep 17 00:00:00 2001 From: laksjdjf Date: Fri, 24 Feb 2023 14:26:55 +0900 Subject: [PATCH 022/104] Update script_callbacks.py --- modules/script_callbacks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index c5bb3e71e..d17031355 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -46,7 +46,7 @@ class CFGDenoiserParams: """Total number of sampling steps planned""" self.tensor = tensor - """ Encoder hidden states of condtioning""" + """ Encoder hidden states of conditioning""" self.uncond = uncond """ Encoder hidden states of unconditioning""" From b15bc73c99e6fbbeffdbdbeab39ba30276021d4b Mon Sep 17 00:00:00 2001 From: Brad Smith Date: Fri, 24 Feb 2023 14:22:58 -0500 Subject: [PATCH 023/104] sort upscalers by name --- modules/modelloader.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/modules/modelloader.py b/modules/modelloader.py index fc3f6249f..a7ac338c2 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -6,7 +6,7 @@ from urllib.parse import urlparse from basicsr.utils.download_util import load_file_from_url from modules import shared -from modules.upscaler import Upscaler +from modules.upscaler import Upscaler, UpscalerNone from modules.paths import script_path, models_path @@ -169,4 +169,8 @@ def load_upscalers(): scaler = cls(commandline_options.get(cmd_name, None)) datas += scaler.scalers - shared.sd_upscalers = datas + shared.sd_upscalers = sorted( + datas, + # Special case for UpscalerNone keeps it at the beginning of the list. + key=lambda x: x.name if not isinstance(x.scaler, UpscalerNone) else "" + ) From aa108bd02a8282e8213fa6c5967e3c47e49bb43f Mon Sep 17 00:00:00 2001 From: missionfloyd Date: Fri, 24 Feb 2023 20:57:18 -0700 Subject: [PATCH 024/104] Add lossless webp option --- modules/images.py | 2 +- modules/shared.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/modules/images.py b/modules/images.py index 5b80c23e1..7df2b08c7 100644 --- a/modules/images.py +++ b/modules/images.py @@ -556,7 +556,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i elif image_to_save.mode == 'I;16': image_to_save = image_to_save.point(lambda p: p * 0.0038910505836576).convert("RGB" if extension.lower() == ".webp" else "L") - image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality) + image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality, lossless=opts.webp_lossless) if opts.enable_pnginfo and info is not None: exif_bytes = piexif.dump({ diff --git a/modules/shared.py b/modules/shared.py index 805f9cc19..511019880 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -327,6 +327,7 @@ options_templates.update(options_section(('saving-images', "Saving images/grids" "save_images_before_highres_fix": OptionInfo(False, "Save a copy of image before applying highres fix."), "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"), "jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}), + "webp_lossless": OptionInfo(False, "Use lossless compression for webp images"), "export_for_4chan": OptionInfo(True, "If the saved image file size is above the limit, or its either width or height are above the limit, save a downscaled copy as JPG"), "img_downscale_threshold": OptionInfo(4.0, "File size limit for the above option, MB", gr.Number), "target_side_length": OptionInfo(4000, "Width/height limit for the above option, in pixels", gr.Number), From ed43a822b2df601fd1d6d2a3b97ae5924c06ca98 Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Sat, 25 Feb 2023 12:56:03 -0500 Subject: [PATCH 025/104] fix progressbar --- javascript/progressbar.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/javascript/progressbar.js b/javascript/progressbar.js index ff6d757ba..9ccc9da46 100644 --- a/javascript/progressbar.js +++ b/javascript/progressbar.js @@ -139,7 +139,7 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre var divProgress = document.createElement('div') divProgress.className='progressDiv' - divProgress.style.display = opts.show_progressbar ? "" : "none" + divProgress.style.display = opts.show_progressbar ? "block" : "none" var divInner = document.createElement('div') divInner.className='progress' From 6d92d95a33e46aea7a7b8c136a76a621d7fc4f52 Mon Sep 17 00:00:00 2001 From: Adam Huganir Date: Sat, 25 Feb 2023 19:15:06 +0000 Subject: [PATCH 026/104] git 3.1.30 api change --- modules/extensions.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/modules/extensions.py b/modules/extensions.py index 3eef9eaf6..ed4b58fe3 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -66,7 +66,7 @@ class Extension: def check_updates(self): repo = git.Repo(self.path) - for fetch in repo.remote().fetch("--dry-run"): + for fetch in repo.remote().fetch(dry_run=True): if fetch.flags != fetch.HEAD_UPTODATE: self.can_update = True self.status = "behind" @@ -79,8 +79,8 @@ class Extension: repo = git.Repo(self.path) # Fix: `error: Your local changes to the following files would be overwritten by merge`, # because WSL2 Docker set 755 file permissions instead of 644, this results to the error. - repo.git.fetch('--all') - repo.git.reset('--hard', 'origin') + repo.git.fetch(all=True) + repo.git.reset('origin', hard=True) def list_extensions(): From 3c6459154fb115ea7cf1a0c5f3f0761a192dfea3 Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Mon, 27 Feb 2023 17:28:04 -0500 Subject: [PATCH 027/104] add check for resulting image size --- modules/shared.py | 1 + scripts/xyz_grid.py | 6 ++++++ 2 files changed, 7 insertions(+) diff --git a/modules/shared.py b/modules/shared.py index 805f9cc19..ec08b7bec 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -330,6 +330,7 @@ options_templates.update(options_section(('saving-images', "Saving images/grids" "export_for_4chan": OptionInfo(True, "If the saved image file size is above the limit, or its either width or height are above the limit, save a downscaled copy as JPG"), "img_downscale_threshold": OptionInfo(4.0, "File size limit for the above option, MB", gr.Number), "target_side_length": OptionInfo(4000, "Width/height limit for the above option, in pixels", gr.Number), + "img_max_size_mp": OptionInfo(200, "Maximum image size, in megapixels", gr.Number), "use_original_name_batch": OptionInfo(True, "Use original name for output filename during batch process in extras tab"), "use_upscaler_name_as_suffix": OptionInfo(False, "Use upscaler name as filename suffix in the extras tab"), diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py index 53511b121..1ba954ac6 100644 --- a/scripts/xyz_grid.py +++ b/scripts/xyz_grid.py @@ -484,6 +484,12 @@ class Script(scripts.Script): z_opt = self.current_axis_options[z_type] zs = process_axis(z_opt, z_values) + # this could be moved to common code, but unlikely to be ever triggered anywhere else + Image.MAX_IMAGE_PIXELS = opts.img_max_size_mp * 1.1 # allow 10% overhead for margins and legend + grid_mp = round(len(xs) * len(ys) * len(zs) * p.width * p.height / 1000000) + if grid_mp > opts.img_max_size_mp: + return Processed(p, [], p.seed, info=f'Error: Resulting grid would be too large ({grid_mp} MPixels) (max configured size is {opts.img_max_size_mp} MPixels)') + def fix_axis_seeds(axis_opt, axis_list): if axis_opt.label in ['Seed', 'Var. seed']: return [int(random.randrange(4294967294)) if val is None or val == '' or val == -1 else val for val in axis_list] From 3b6de96467b0ee6d59f3aaf4bafd1467633e14c6 Mon Sep 17 00:00:00 2001 From: Vespinian Date: Sun, 26 Feb 2023 19:17:58 -0500 Subject: [PATCH 028/104] Added alwayson_script_name and alwayson_script_args to api Added 2 additional possible entries in the api request: alwayson_script_name, a string list, and, alwayson_script_args, a list of list containing the args of each script. This allows us to send args to always on script and keep backwards compatibility with old script_name and script_arg api params --- modules/api/api.py | 111 +++++++++++++++++++++++++++++++++++++----- modules/api/models.py | 4 +- 2 files changed, 100 insertions(+), 15 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index 5a9ac5f1a..a1cdebb8d 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -163,20 +163,26 @@ class Api: raise HTTPException(status_code=401, detail="Incorrect username or password", headers={"WWW-Authenticate": "Basic"}) - def get_script(self, script_name, script_runner): - if script_name is None: + def get_selectable_script(self, script_name, script_runner): + if script_name is None or script_name == "": return None, None - if not script_runner.scripts: - script_runner.initialize_scripts(False) - ui.create_ui() - script_idx = script_name_to_index(script_name, script_runner.selectable_scripts) script = script_runner.selectable_scripts[script_idx] return script, script_idx + def get_script(self, script_name, script_runner): + for script in script_runner.scripts: + if script_name.lower() == script.title().lower(): + return script + return None + def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): - script, script_idx = self.get_script(txt2imgreq.script_name, scripts.scripts_txt2img) + script_runner = scripts.scripts_txt2img + if not script_runner.scripts: + script_runner.initialize_scripts(False) + ui.create_ui() + api_selectable_scripts, api_selectable_script_idx = self.get_selectable_script(txt2imgreq.script_name, script_runner) populate = txt2imgreq.copy(update={ # Override __init__ params "sampler_name": validate_sampler_name(txt2imgreq.sampler_name or txt2imgreq.sampler_index), @@ -184,22 +190,59 @@ class Api: "do_not_save_grid": True } ) + if populate.sampler_name: populate.sampler_index = None # prevent a warning later on args = vars(populate) args.pop('script_name', None) + args.pop('script_args', None) # will refeed them later with script_args + args.pop('alwayson_script_name', None) + args.pop('alwayson_script_args', None) + + #find max idx from the scripts in runner and generate a none array to init script_args + last_arg_index = 1 + for script in script_runner.scripts: + if last_arg_index < script.args_to: + last_arg_index = script.args_to + # None everywhere exepct position 0 to initialize script args + script_args = [None]*last_arg_index + # position 0 in script_arg is the idx+1 of the selectable script that is going to be run + if api_selectable_scripts: + script_args[api_selectable_scripts.args_from:api_selectable_scripts.args_to] = txt2imgreq.script_args + script_args[0] = api_selectable_script_idx + 1 + else: + # if 0 then none + script_args[0] = 0 + + # Now check for always on scripts + if len(txt2imgreq.alwayson_script_name) > 0: + # always on script with no arg should always run, but if you include their name in the api request, send an empty list for there args + if len(txt2imgreq.alwayson_script_name) != len(txt2imgreq.alwayson_script_args): + raise HTTPException(status_code=422, detail=f"Number of script names and number of script arg lists doesn't match") + + for alwayson_script_name, alwayson_script_args in zip(txt2imgreq.alwayson_script_name, txt2imgreq.alwayson_script_args): + alwayson_script = self.get_script(alwayson_script_name, script_runner) + if alwayson_script == None: + raise HTTPException(status_code=422, detail=f"always on script {alwayson_script_name} not found") + # Selectable script in always on script param check + if alwayson_script.alwayson == False: + raise HTTPException(status_code=422, detail=f"Cannot have a selectable script in the always on scripts params") + if alwayson_script_args != []: + script_args[alwayson_script.args_from:alwayson_script.args_to] = alwayson_script_args with self.queue_lock: p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args) + p.scripts = script_runner shared.state.begin() - if script is not None: + if api_selectable_scripts != None: + p.script_args = script_args p.outpath_grids = opts.outdir_txt2img_grids p.outpath_samples = opts.outdir_txt2img_samples - p.script_args = [script_idx + 1] + [None] * (script.args_from - 1) + p.script_args processed = scripts.scripts_txt2img.run(p, *p.script_args) else: + p.script_args = tuple(script_args) processed = process_images(p) shared.state.end() @@ -212,12 +255,16 @@ class Api: if init_images is None: raise HTTPException(status_code=404, detail="Init image not found") - script, script_idx = self.get_script(img2imgreq.script_name, scripts.scripts_img2img) - mask = img2imgreq.mask if mask: mask = decode_base64_to_image(mask) + script_runner = scripts.scripts_img2img + if not script_runner.scripts: + script_runner.initialize_scripts(True) + ui.create_ui() + api_selectable_scripts, api_selectable_script_idx = self.get_selectable_script(img2imgreq.script_name, script_runner) + populate = img2imgreq.copy(update={ # Override __init__ params "sampler_name": validate_sampler_name(img2imgreq.sampler_name or img2imgreq.sampler_index), "do_not_save_samples": True, @@ -225,24 +272,62 @@ class Api: "mask": mask } ) + if populate.sampler_name: populate.sampler_index = None # prevent a warning later on args = vars(populate) args.pop('include_init_images', None) # this is meant to be done by "exclude": True in model, but it's for a reason that I cannot determine. args.pop('script_name', None) + args.pop('script_args', None) # will refeed them later with script_args + args.pop('alwayson_script_name', None) + args.pop('alwayson_script_args', None) + + #find max idx from the scripts in runner and generate a none array to init script_args + last_arg_index = 1 + for script in script_runner.scripts: + if last_arg_index < script.args_to: + last_arg_index = script.args_to + # None everywhere exepct position 0 to initialize script args + script_args = [None]*last_arg_index + # position 0 in script_arg is the idx+1 of the selectable script that is going to be run + if api_selectable_scripts: + script_args[api_selectable_scripts.args_from:api_selectable_scripts.args_to] = img2imgreq.script_args + script_args[0] = api_selectable_script_idx + 1 + else: + # if 0 then none + script_args[0] = 0 + + # Now check for always on scripts + if len(img2imgreq.alwayson_script_name) > 0: + # always on script with no arg should always run, but if you include their name in the api request, send an empty list for there args + if len(img2imgreq.alwayson_script_name) != len(img2imgreq.alwayson_script_args): + raise HTTPException(status_code=422, detail=f"Number of script names and number of script arg lists doesn't match") + + for alwayson_script_name, alwayson_script_args in zip(img2imgreq.alwayson_script_name, img2imgreq.alwayson_script_args): + alwayson_script = self.get_script(alwayson_script_name, script_runner) + if alwayson_script == None: + raise HTTPException(status_code=422, detail=f"always on script {alwayson_script_name} not found") + # Selectable script in always on script param check + if alwayson_script.alwayson == False: + raise HTTPException(status_code=422, detail=f"Cannot have a selectable script in the always on scripts params") + if alwayson_script_args != []: + script_args[alwayson_script.args_from:alwayson_script.args_to] = alwayson_script_args + with self.queue_lock: p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args) p.init_images = [decode_base64_to_image(x) for x in init_images] + p.scripts = script_runner shared.state.begin() - if script is not None: + if api_selectable_scripts != None: + p.script_args = script_args p.outpath_grids = opts.outdir_img2img_grids p.outpath_samples = opts.outdir_img2img_samples - p.script_args = [script_idx + 1] + [None] * (script.args_from - 1) + p.script_args processed = scripts.scripts_img2img.run(p, *p.script_args) else: + p.script_args = tuple(script_args) processed = process_images(p) shared.state.end() diff --git a/modules/api/models.py b/modules/api/models.py index cba43d3b1..86c701780 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -100,13 +100,13 @@ class PydanticModelGenerator: StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator( "StableDiffusionProcessingTxt2Img", StableDiffusionProcessingTxt2Img, - [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "script_name", "type": str, "default": None}, {"key": "script_args", "type": list, "default": []}] + [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "script_name", "type": str, "default": None}, {"key": "script_args", "type": list, "default": []}, {"key": "alwayson_script_name", "type": list, "default": []}, {"key": "alwayson_script_args", "type": list, "default": []}] ).generate_model() StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator( "StableDiffusionProcessingImg2Img", StableDiffusionProcessingImg2Img, - [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}, {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}, {"key": "script_name", "type": str, "default": None}, {"key": "script_args", "type": list, "default": []}] + [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}, {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}, {"key": "script_name", "type": str, "default": None}, {"key": "script_args", "type": list, "default": []}, {"key": "alwayson_script_name", "type": list, "default": []}, {"key": "alwayson_script_args", "type": list, "default": []}] ).generate_model() class TextToImageResponse(BaseModel): From a39c4cf766a9b0f18972fc52bcf5189173f434c6 Mon Sep 17 00:00:00 2001 From: Vespinian Date: Mon, 27 Feb 2023 23:27:33 -0500 Subject: [PATCH 029/104] small refactor of api.py --- modules/api/api.py | 125 ++++++++++++++++++--------------------------- 1 file changed, 51 insertions(+), 74 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index a1cdebb8d..d4c0c1521 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -172,17 +172,53 @@ class Api: return script, script_idx def get_script(self, script_name, script_runner): + if script_name is None or script_name == "": + return None, None + + script_idx = script_name_to_index(script_name, script_runner.scripts) + return script_runner.scripts[script_idx] + + def init_script_args(self, request, selectable_scripts, selectable_idx, script_runner): + #find max idx from the scripts in runner and generate a none array to init script_args + last_arg_index = 1 for script in script_runner.scripts: - if script_name.lower() == script.title().lower(): - return script - return None + if last_arg_index < script.args_to: + last_arg_index = script.args_to + # None everywhere except position 0 to initialize script args + script_args = [None]*last_arg_index + # position 0 in script_arg is the idx+1 of the selectable script that is going to be run when using scripts.scripts_*2img.run() + if selectable_scripts: + script_args[selectable_scripts.args_from:selectable_scripts.args_to] = request.script_args + script_args[0] = selectable_idx + 1 + else: + # if 0 then none + script_args[0] = 0 + + # Now check for always on scripts + if request.alwayson_script_name and (len(request.alwayson_script_name) > 0): + # always on script with no arg should always run, but if you include their name in the api request, send an empty list for there args + if not request.alwayson_script_args: + raise HTTPException(status_code=422, detail=f"Script {request.alwayson_script_name} has no arg list") + if len(request.alwayson_script_name) != len(request.alwayson_script_args): + raise HTTPException(status_code=422, detail=f"Number of script names and number of script arg lists doesn't match") + + for alwayson_script_name, alwayson_script_args in zip(request.alwayson_script_name, request.alwayson_script_args): + alwayson_script = self.get_script(alwayson_script_name, script_runner) + if alwayson_script == None: + raise HTTPException(status_code=422, detail=f"always on script {alwayson_script_name} not found") + # Selectable script in always on script param check + if alwayson_script.alwayson == False: + raise HTTPException(status_code=422, detail=f"Cannot have a selectable script in the always on scripts params") + if alwayson_script_args != []: + script_args[alwayson_script.args_from:alwayson_script.args_to] = alwayson_script_args + return script_args def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): script_runner = scripts.scripts_txt2img if not script_runner.scripts: script_runner.initialize_scripts(False) ui.create_ui() - api_selectable_scripts, api_selectable_script_idx = self.get_selectable_script(txt2imgreq.script_name, script_runner) + selectable_scripts, selectable_script_idx = self.get_selectable_script(txt2imgreq.script_name, script_runner) populate = txt2imgreq.copy(update={ # Override __init__ params "sampler_name": validate_sampler_name(txt2imgreq.sampler_name or txt2imgreq.sampler_index), @@ -196,53 +232,24 @@ class Api: args = vars(populate) args.pop('script_name', None) - args.pop('script_args', None) # will refeed them later with script_args + args.pop('script_args', None) # will refeed them to the pipeline directly after initializing them args.pop('alwayson_script_name', None) args.pop('alwayson_script_args', None) - #find max idx from the scripts in runner and generate a none array to init script_args - last_arg_index = 1 - for script in script_runner.scripts: - if last_arg_index < script.args_to: - last_arg_index = script.args_to - # None everywhere exepct position 0 to initialize script args - script_args = [None]*last_arg_index - # position 0 in script_arg is the idx+1 of the selectable script that is going to be run - if api_selectable_scripts: - script_args[api_selectable_scripts.args_from:api_selectable_scripts.args_to] = txt2imgreq.script_args - script_args[0] = api_selectable_script_idx + 1 - else: - # if 0 then none - script_args[0] = 0 - - # Now check for always on scripts - if len(txt2imgreq.alwayson_script_name) > 0: - # always on script with no arg should always run, but if you include their name in the api request, send an empty list for there args - if len(txt2imgreq.alwayson_script_name) != len(txt2imgreq.alwayson_script_args): - raise HTTPException(status_code=422, detail=f"Number of script names and number of script arg lists doesn't match") - - for alwayson_script_name, alwayson_script_args in zip(txt2imgreq.alwayson_script_name, txt2imgreq.alwayson_script_args): - alwayson_script = self.get_script(alwayson_script_name, script_runner) - if alwayson_script == None: - raise HTTPException(status_code=422, detail=f"always on script {alwayson_script_name} not found") - # Selectable script in always on script param check - if alwayson_script.alwayson == False: - raise HTTPException(status_code=422, detail=f"Cannot have a selectable script in the always on scripts params") - if alwayson_script_args != []: - script_args[alwayson_script.args_from:alwayson_script.args_to] = alwayson_script_args + script_args = self.init_script_args(txt2imgreq, selectable_scripts, selectable_script_idx, script_runner) with self.queue_lock: p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args) p.scripts = script_runner shared.state.begin() - if api_selectable_scripts != None: + if selectable_scripts != None: p.script_args = script_args p.outpath_grids = opts.outdir_txt2img_grids p.outpath_samples = opts.outdir_txt2img_samples - processed = scripts.scripts_txt2img.run(p, *p.script_args) + processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here else: - p.script_args = tuple(script_args) + p.script_args = tuple(script_args) # Need to pass args as tuple here processed = process_images(p) shared.state.end() @@ -263,7 +270,7 @@ class Api: if not script_runner.scripts: script_runner.initialize_scripts(True) ui.create_ui() - api_selectable_scripts, api_selectable_script_idx = self.get_selectable_script(img2imgreq.script_name, script_runner) + selectable_scripts, selectable_script_idx = self.get_selectable_script(img2imgreq.script_name, script_runner) populate = img2imgreq.copy(update={ # Override __init__ params "sampler_name": validate_sampler_name(img2imgreq.sampler_name or img2imgreq.sampler_index), @@ -279,41 +286,11 @@ class Api: args = vars(populate) args.pop('include_init_images', None) # this is meant to be done by "exclude": True in model, but it's for a reason that I cannot determine. args.pop('script_name', None) - args.pop('script_args', None) # will refeed them later with script_args + args.pop('script_args', None) # will refeed them to the pipeline directly after initializing them args.pop('alwayson_script_name', None) args.pop('alwayson_script_args', None) - #find max idx from the scripts in runner and generate a none array to init script_args - last_arg_index = 1 - for script in script_runner.scripts: - if last_arg_index < script.args_to: - last_arg_index = script.args_to - # None everywhere exepct position 0 to initialize script args - script_args = [None]*last_arg_index - # position 0 in script_arg is the idx+1 of the selectable script that is going to be run - if api_selectable_scripts: - script_args[api_selectable_scripts.args_from:api_selectable_scripts.args_to] = img2imgreq.script_args - script_args[0] = api_selectable_script_idx + 1 - else: - # if 0 then none - script_args[0] = 0 - - # Now check for always on scripts - if len(img2imgreq.alwayson_script_name) > 0: - # always on script with no arg should always run, but if you include their name in the api request, send an empty list for there args - if len(img2imgreq.alwayson_script_name) != len(img2imgreq.alwayson_script_args): - raise HTTPException(status_code=422, detail=f"Number of script names and number of script arg lists doesn't match") - - for alwayson_script_name, alwayson_script_args in zip(img2imgreq.alwayson_script_name, img2imgreq.alwayson_script_args): - alwayson_script = self.get_script(alwayson_script_name, script_runner) - if alwayson_script == None: - raise HTTPException(status_code=422, detail=f"always on script {alwayson_script_name} not found") - # Selectable script in always on script param check - if alwayson_script.alwayson == False: - raise HTTPException(status_code=422, detail=f"Cannot have a selectable script in the always on scripts params") - if alwayson_script_args != []: - script_args[alwayson_script.args_from:alwayson_script.args_to] = alwayson_script_args - + script_args = self.init_script_args(img2imgreq, selectable_scripts, selectable_script_idx, script_runner) with self.queue_lock: p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args) @@ -321,13 +298,13 @@ class Api: p.scripts = script_runner shared.state.begin() - if api_selectable_scripts != None: + if selectable_scripts != None: p.script_args = script_args p.outpath_grids = opts.outdir_img2img_grids p.outpath_samples = opts.outdir_img2img_samples - processed = scripts.scripts_img2img.run(p, *p.script_args) + processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here else: - p.script_args = tuple(script_args) + p.script_args = tuple(script_args) # Need to pass args as tuple here processed = process_images(p) shared.state.end() From c6c2a59333c77dffff49a748bfed8c54af6e2abd Mon Sep 17 00:00:00 2001 From: Vespinian Date: Mon, 27 Feb 2023 23:45:59 -0500 Subject: [PATCH 030/104] comment clarification --- modules/api/api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/api/api.py b/modules/api/api.py index d4c0c1521..248922d29 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -191,7 +191,7 @@ class Api: script_args[selectable_scripts.args_from:selectable_scripts.args_to] = request.script_args script_args[0] = selectable_idx + 1 else: - # if 0 then none + # when [0] = 0 no selectable script to run script_args[0] = 0 # Now check for always on scripts From 1e30e4d9ebd9c36ccee43ec0e61c6ab490171614 Mon Sep 17 00:00:00 2001 From: Ju1-js <40339350+Ju1-js@users.noreply.github.com> Date: Tue, 28 Feb 2023 15:55:12 -0800 Subject: [PATCH 031/104] Gradio auth logic fix - Handle empty/newlines When the massive one-liner was split into multiple lines, it lost the ability to handle newlines. This removes empty strings & newline characters from the logins. It also closes the file so it's more robust if the garbage collection function is ever changed. --- webui.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/webui.py b/webui.py index 9e8b486aa..5e925fa77 100644 --- a/webui.py +++ b/webui.py @@ -209,11 +209,12 @@ def webui(): gradio_auth_creds = [] if cmd_opts.gradio_auth: - gradio_auth_creds += cmd_opts.gradio_auth.strip('"').replace('\n', '').split(',') + gradio_auth_creds += [x.strip() for x in cmd_opts.gradio_auth.strip('"').replace('/n', '').split(',') if x.strip()] if cmd_opts.gradio_auth_path: with open(cmd_opts.gradio_auth_path, 'r', encoding="utf8") as file: for line in file.readlines(): - gradio_auth_creds += [x.strip() for x in line.split(',')] + gradio_auth_creds += [x.strip() for x in line.split(',') if x.strip()] + file.close() app, local_url, share_url = shared.demo.launch( share=cmd_opts.share, From 7990ed92be7f34e609b441252ff97ae1504b0a3f Mon Sep 17 00:00:00 2001 From: Ju1-js <40339350+Ju1-js@users.noreply.github.com> Date: Tue, 28 Feb 2023 22:05:47 -0800 Subject: [PATCH 032/104] Slash was facing the wrong way --- webui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/webui.py b/webui.py index 5e925fa77..d60c4e5d1 100644 --- a/webui.py +++ b/webui.py @@ -209,7 +209,7 @@ def webui(): gradio_auth_creds = [] if cmd_opts.gradio_auth: - gradio_auth_creds += [x.strip() for x in cmd_opts.gradio_auth.strip('"').replace('/n', '').split(',') if x.strip()] + gradio_auth_creds += [x.strip() for x in cmd_opts.gradio_auth.strip('"').replace('\n', '').split(',') if x.strip()] if cmd_opts.gradio_auth_path: with open(cmd_opts.gradio_auth_path, 'r', encoding="utf8") as file: for line in file.readlines(): From b14d8b61bdc4cc2e6c0ea484a5ba8fd370231227 Mon Sep 17 00:00:00 2001 From: Adam Huganir Date: Wed, 1 Mar 2023 13:07:37 -0500 Subject: [PATCH 033/104] version bump for git python due to CVE-2022-24439 required version for CVE-2022-24439 is >= 3.130 --- requirements_versions.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements_versions.txt b/requirements_versions.txt index 331d0fe86..41e0ccc53 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -23,7 +23,7 @@ torchdiffeq==0.2.3 kornia==0.6.7 lark==1.1.2 inflection==0.5.1 -GitPython==3.1.27 +GitPython==3.1.30 torchsde==0.2.5 safetensors==0.2.7 httpcore<=0.15 From fc3063d9b924c094b59229269f4afe722b120d88 Mon Sep 17 00:00:00 2001 From: Ju1-js <40339350+Ju1-js@users.noreply.github.com> Date: Wed, 1 Mar 2023 18:25:23 -0800 Subject: [PATCH 034/104] Remove unnecessary line --- webui.py | 1 - 1 file changed, 1 deletion(-) diff --git a/webui.py b/webui.py index d60c4e5d1..be39fa8dc 100644 --- a/webui.py +++ b/webui.py @@ -214,7 +214,6 @@ def webui(): with open(cmd_opts.gradio_auth_path, 'r', encoding="utf8") as file: for line in file.readlines(): gradio_auth_creds += [x.strip() for x in line.split(',') if x.strip()] - file.close() app, local_url, share_url = shared.demo.launch( share=cmd_opts.share, From 23d4fb5bf2400622d00ca5fe489fadb160ee7c47 Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Fri, 3 Mar 2023 08:29:10 -0500 Subject: [PATCH 035/104] allow saving of images via api --- modules/api/api.py | 8 ++++---- modules/api/models.py | 4 ++-- modules/images.py | 3 +++ 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index 5a9ac5f1a..6b939daa9 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -180,8 +180,8 @@ class Api: populate = txt2imgreq.copy(update={ # Override __init__ params "sampler_name": validate_sampler_name(txt2imgreq.sampler_name or txt2imgreq.sampler_index), - "do_not_save_samples": True, - "do_not_save_grid": True + "do_not_save_samples": True if not 'do_not_save_samples' in vars(txt2imgreq) else txt2imgreq.do_not_save_samples, + "do_not_save_grid": True if not 'do_not_save_grid' in vars(txt2imgreq) else txt2imgreq.do_not_save_grid, } ) if populate.sampler_name: @@ -220,8 +220,8 @@ class Api: populate = img2imgreq.copy(update={ # Override __init__ params "sampler_name": validate_sampler_name(img2imgreq.sampler_name or img2imgreq.sampler_index), - "do_not_save_samples": True, - "do_not_save_grid": True, + "do_not_save_samples": True if not 'do_not_save_samples' in img2imgreq else img2imgreq.do_not_save_samples, + "do_not_save_grid": True if not 'do_not_save_grid' in img2imgreq else img2imgreq.do_not_save_grid, "mask": mask } ) diff --git a/modules/api/models.py b/modules/api/models.py index cba43d3b1..a947e6ac3 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -14,8 +14,8 @@ API_NOT_ALLOWED = [ "outpath_samples", "outpath_grids", "sampler_index", - "do_not_save_samples", - "do_not_save_grid", + # "do_not_save_samples", + # "do_not_save_grid", "extra_generation_params", "overlay_images", "do_not_reload_embeddings", diff --git a/modules/images.py b/modules/images.py index 5b80c23e1..f8e62b718 100644 --- a/modules/images.py +++ b/modules/images.py @@ -489,6 +489,9 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i """ namegen = FilenameGenerator(p, seed, prompt, image) + if path is None: # set default path to avoid errors when functions are triggered manually or via api and param is not set + path = opts.outdir_save + if save_to_dirs is None: save_to_dirs = (grid and opts.grid_save_to_dirs) or (not grid and opts.save_to_dirs and not no_prompt) From f8e219bad9f33cde94cd31fff3edd70946612541 Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Fri, 3 Mar 2023 09:00:52 -0500 Subject: [PATCH 036/104] allow api requests to specify do not send images in response --- modules/api/api.py | 10 ++++++++-- modules/api/models.py | 18 ++++++++++++++++-- 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index 6b939daa9..7da9081b9 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -190,6 +190,9 @@ class Api: args = vars(populate) args.pop('script_name', None) + send_images = True if not 'do_not_send_images' in args else not args['do_not_send_images'] + args.pop('do_not_send_images', None) + with self.queue_lock: p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args) @@ -203,7 +206,7 @@ class Api: processed = process_images(p) shared.state.end() - b64images = list(map(encode_pil_to_base64, processed.images)) + b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else [] return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js()) @@ -232,6 +235,9 @@ class Api: args.pop('include_init_images', None) # this is meant to be done by "exclude": True in model, but it's for a reason that I cannot determine. args.pop('script_name', None) + send_images = True if not 'do_not_send_images' in args else not args['do_not_send_images'] + args.pop('do_not_send_images', None) + with self.queue_lock: p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args) p.init_images = [decode_base64_to_image(x) for x in init_images] @@ -246,7 +252,7 @@ class Api: processed = process_images(p) shared.state.end() - b64images = list(map(encode_pil_to_base64, processed.images)) + b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else [] if not img2imgreq.include_init_images: img2imgreq.init_images = None diff --git a/modules/api/models.py b/modules/api/models.py index a947e6ac3..aa4ea5d5f 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -100,13 +100,27 @@ class PydanticModelGenerator: StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator( "StableDiffusionProcessingTxt2Img", StableDiffusionProcessingTxt2Img, - [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "script_name", "type": str, "default": None}, {"key": "script_args", "type": list, "default": []}] + [ + {"key": "sampler_index", "type": str, "default": "Euler"}, + {"key": "script_name", "type": str, "default": None}, + {"key": "script_args", "type": list, "default": []}, + {"key": "do_not_send_images", "type": bool, "default": False} + ] ).generate_model() StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator( "StableDiffusionProcessingImg2Img", StableDiffusionProcessingImg2Img, - [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}, {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}, {"key": "script_name", "type": str, "default": None}, {"key": "script_args", "type": list, "default": []}] + [ + {"key": "sampler_index", "type": str, "default": "Euler"}, + {"key": "init_images", "type": list, "default": None}, + {"key": "denoising_strength", "type": float, "default": 0.75}, + {"key": "mask", "type": str, "default": None}, + {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}, + {"key": "script_name", "type": str, "default": None}, + {"key": "script_args", "type": list, "default": []}, + {"key": "do_not_send_images", "type": bool, "default": False} + ] ).generate_model() class TextToImageResponse(BaseModel): From c48bbccf12f13cf309f532d70494ff04c27bcc2a Mon Sep 17 00:00:00 2001 From: Yea chen Date: Sat, 4 Mar 2023 11:46:07 +0800 Subject: [PATCH 037/104] add: /sdapi/v1/scripts in API API for get scripts list --- modules/api/api.py | 13 +++++++++++++ modules/api/models.py | 4 ++++ 2 files changed, 17 insertions(+) diff --git a/modules/api/api.py b/modules/api/api.py index 5a9ac5f1a..46cb7c811 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -150,6 +150,7 @@ class Api: self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=TrainResponse) self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=TrainResponse) self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=MemoryResponse) + self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=ScriptsList) def add_api_route(self, path: str, endpoint, **kwargs): if shared.cmd_opts.api_auth: @@ -174,6 +175,18 @@ class Api: script_idx = script_name_to_index(script_name, script_runner.selectable_scripts) script = script_runner.selectable_scripts[script_idx] return script, script_idx + + def get_scripts_list(self): + t2ilist = [] + i2ilist = [] + + for a in scripts.scripts_txt2img.titles: + t2ilist.append(str(a.lower())) + + for b in scripts.scripts_img2img.titles: + i2ilist.append(str(b.lower())) + + return ScriptsList(txt2img = t2ilist, img2img = i2ilist) def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): script, script_idx = self.get_script(txt2imgreq.script_name, scripts.scripts_txt2img) diff --git a/modules/api/models.py b/modules/api/models.py index cba43d3b1..db739f2b4 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -267,3 +267,7 @@ class EmbeddingsResponse(BaseModel): class MemoryResponse(BaseModel): ram: dict = Field(title="RAM", description="System memory stats") cuda: dict = Field(title="CUDA", description="nVidia CUDA memory stats") + +class ScriptsList(BaseModel): + txt2img: list = Field(default=None,title="Txt2img", description="Titles of scripts (txt2img)") + img2img: list = Field(default=None,title="Img2img", description="Titles of scripts (img2img)") \ No newline at end of file From 2d9635cce5d8123c53e5c8233bab7c9751ae03ba Mon Sep 17 00:00:00 2001 From: DejitaruJin Date: Sat, 4 Mar 2023 12:51:55 -0500 Subject: [PATCH 038/104] Fix display and save order for X/Y/Z Grid script --- scripts/xyz_grid.py | 125 ++++++++++++++++++++++++-------------------- 1 file changed, 69 insertions(+), 56 deletions(-) diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py index 53511b121..8ede2aa0c 100644 --- a/scripts/xyz_grid.py +++ b/scripts/xyz_grid.py @@ -25,8 +25,6 @@ from modules.ui_components import ToolButton fill_values_symbol = "\U0001f4d2" # 📒 -AxisInfo = namedtuple('AxisInfo', ['axis', 'values']) - def apply_field(field): def fun(p, x, xs): @@ -188,7 +186,6 @@ axis_options = [ AxisOption("Steps", int, apply_field("steps")), AxisOptionTxt2Img("Hires steps", int, apply_field("hr_second_pass_steps")), AxisOption("CFG Scale", float, apply_field("cfg_scale")), - AxisOptionImg2Img("Image CFG Scale", float, apply_field("image_cfg_scale")), AxisOption("Prompt S/R", str, apply_prompt, format_value=format_value), AxisOption("Prompt order", str_permutations, apply_order, format_value=format_value_join_list), AxisOptionTxt2Img("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]), @@ -213,49 +210,47 @@ def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend ver_texts = [[images.GridAnnotation(y)] for y in y_labels] title_texts = [[images.GridAnnotation(z)] for z in z_labels] - # Temporary list of all the images that are generated to be populated into the grid. - # Will be filled with empty images for any individual step that fails to process properly - image_cache = [None] * (len(xs) * len(ys) * len(zs)) + list_size = (len(xs) * len(ys) * len(zs)) processed_result = None - cell_mode = "P" - cell_size = (1, 1) - state.job_count = len(xs) * len(ys) * len(zs) * p.n_iter + state.job_count = list_size * p.n_iter def process_cell(x, y, z, ix, iy, iz): - nonlocal image_cache, processed_result, cell_mode, cell_size + nonlocal processed_result def index(ix, iy, iz): return ix + iy * len(xs) + iz * len(xs) * len(ys) - state.job = f"{index(ix, iy, iz) + 1} out of {len(xs) * len(ys) * len(zs)}" + state.job = f"{index(ix, iy, iz) + 1} out of {list_size}" processed: Processed = cell(x, y, z) - try: - # this dereference will throw an exception if the image was not processed - # (this happens in cases such as if the user stops the process from the UI) - processed_image = processed.images[0] + if processed_result is None: + # Use our first processed result object as a template container to hold our full results + processed_result = copy(processed) + processed_result.images = [None] * list_size + processed_result.all_prompts = [None] * list_size + processed_result.all_seeds = [None] * list_size + processed_result.infotexts = [None] * list_size + processed_result.index_of_first_image = 0 - if processed_result is None: - # Use our first valid processed result as a template container to hold our full results - processed_result = copy(processed) - cell_mode = processed_image.mode - cell_size = processed_image.size - processed_result.images = [Image.new(cell_mode, cell_size)] - processed_result.all_prompts = [processed.prompt] - processed_result.all_seeds = [processed.seed] - processed_result.infotexts = [processed.infotexts[0]] + idx = index(ix, iy, iz) + if processed.images: + # Non-empty list indicates some degree of success. + processed_result.images[idx] = processed.images[0] + processed_result.all_prompts[idx] = processed.prompt + processed_result.all_seeds[idx] = processed.seed + processed_result.infotexts[idx] = processed.infotexts[0] + else: + cell_mode = "P" + cell_size = (processed_result.width, processed_result.height) + if processed_result.images[0] is not None: + cell_mode = processed_result.images[0].mode + #This corrects size in case of batches: + cell_size = processed_result.images[0].size + processed_result.images[idx] = Image.new(cell_mode, cell_size) - image_cache[index(ix, iy, iz)] = processed_image - if include_lone_images: - processed_result.images.append(processed_image) - processed_result.all_prompts.append(processed.prompt) - processed_result.all_seeds.append(processed.seed) - processed_result.infotexts.append(processed.infotexts[0]) - except: - image_cache[index(ix, iy, iz)] = Image.new(cell_mode, cell_size) if first_axes_processed == 'x': for ix, x in enumerate(xs): @@ -289,27 +284,36 @@ def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend process_cell(x, y, z, ix, iy, iz) if not processed_result: + # Should never happen, I've only seen it on one of four open tabs and it needed to refresh. + print("Unexpected error: Processing could not begin, you may need to refresh the tab or restart the service.") + return Processed(p, []) + elif not any(processed_result.images): print("Unexpected error: draw_xyz_grid failed to return even a single processed image") return Processed(p, []) - sub_grids = [None] * len(zs) - for i in range(len(zs)): - start_index = i * len(xs) * len(ys) + z_count = len(zs) + sub_grids = [None] * z_count + for i in range(z_count): + start_index = (i * len(xs) * len(ys)) + i end_index = start_index + len(xs) * len(ys) - grid = images.image_grid(image_cache[start_index:end_index], rows=len(ys)) + grid = images.image_grid(processed_result.images[start_index:end_index], rows=len(ys)) if draw_legend: - grid = images.draw_grid_annotations(grid, cell_size[0], cell_size[1], hor_texts, ver_texts, margin_size) - sub_grids[i] = grid - if include_sub_grids and len(zs) > 1: - processed_result.images.insert(i+1, grid) + grid = images.draw_grid_annotations(grid, processed_result.images[start_index].size[0], processed_result.images[start_index].size[1], hor_texts, ver_texts, margin_size) + processed_result.images.insert(i, grid) + processed_result.all_prompts.insert(i, processed_result.all_prompts[start_index]) + processed_result.all_seeds.insert(i, processed_result.all_seeds[start_index]) + processed_result.infotexts.insert(i, processed_result.infotexts[start_index]) - sub_grid_size = sub_grids[0].size - z_grid = images.image_grid(sub_grids, rows=1) + sub_grid_size = processed_result.images[0].size + z_grid = images.image_grid(processed_result.images[:z_count], rows=1) if draw_legend: z_grid = images.draw_grid_annotations(z_grid, sub_grid_size[0], sub_grid_size[1], title_texts, [[images.GridAnnotation()]]) - processed_result.images[0] = z_grid + processed_result.images.insert(0, z_grid) + processed_result.all_prompts.insert(0, processed_result.all_prompts[0]) + processed_result.all_seeds.insert(0, processed_result.all_seeds[0]) + processed_result.infotexts.insert(0, processed_result.infotexts[0]) - return processed_result, sub_grids + return processed_result class SharedSettingsStackHelper(object): @@ -364,7 +368,7 @@ class Script(scripts.Script): include_lone_images = gr.Checkbox(label='Include Sub Images', value=False, elem_id=self.elem_id("include_lone_images")) include_sub_grids = gr.Checkbox(label='Include Sub Grids', value=False, elem_id=self.elem_id("include_sub_grids")) with gr.Column(): - margin_size = gr.Slider(label="Grid margins (px)", minimum=0, maximum=500, value=0, step=2, elem_id=self.elem_id("margin_size")) + margin_size = gr.Slider(label="Grid margins (px)", min=0, max=500, value=0, step=2, elem_id=self.elem_id("margin_size")) with gr.Row(variant="compact", elem_id="swap_axes"): swap_xy_axes_button = gr.Button(value="Swap X/Y axes", elem_id="xy_grid_swap_axes_button") @@ -526,14 +530,10 @@ class Script(scripts.Script): grid_infotext = [None] - state.xyz_plot_x = AxisInfo(x_opt, xs) - state.xyz_plot_y = AxisInfo(y_opt, ys) - state.xyz_plot_z = AxisInfo(z_opt, zs) - # If one of the axes is very slow to change between (like SD model # checkpoint), then make sure it is in the outer iteration of the nested # `for` loop. - first_axes_processed = 'x' + first_axes_processed = 'z' second_axes_processed = 'y' if x_opt.cost > y_opt.cost and x_opt.cost > z_opt.cost: first_axes_processed = 'x' @@ -593,7 +593,7 @@ class Script(scripts.Script): return res with SharedSettingsStackHelper(): - processed, sub_grids = draw_xyz_grid( + processed = draw_xyz_grid( p, xs=xs, ys=ys, @@ -610,11 +610,24 @@ class Script(scripts.Script): margin_size=margin_size ) - if opts.grid_save and len(sub_grids) > 1: - for sub_grid in sub_grids: - images.save_image(sub_grid, p.outpath_grids, "xyz_grid", info=grid_infotext[0], extension=opts.grid_format, prompt=p.prompt, seed=processed.seed, grid=True, p=p) + z_count = len(zs) - if opts.grid_save: - images.save_image(processed.images[0], p.outpath_grids, "xyz_grid", info=grid_infotext[0], extension=opts.grid_format, prompt=p.prompt, seed=processed.seed, grid=True, p=p) + if not include_lone_images: + # Don't need sub-images anymore, drop from list: + processed.images = processed.images[:z_count+1] + + if opts.grid_save and processed.images: + # Auto-save main and sub-grids: + grid_count = z_count + 1 if z_count > 1 else 1 + for g in range(grid_count): + images.save_image(processed.images[g], p.outpath_grids, "xyz_grid", info=processed.infotexts[g], extension=opts.grid_format, prompt=processed.all_prompts[g], seed=processed.all_seeds[g], grid=True, p=processed) + + if not include_sub_grids: + # Done with sub-grids, drop all related information: + for sg in range(z_count): + del processed.images[1] + del processed.all_prompts[1] + del processed.all_seeds[1] + del processed.infotexts[1] return processed From 2ba880704b970f2870cbae1fe08cea77a21b9213 Mon Sep 17 00:00:00 2001 From: DejitaruJin Date: Sat, 4 Mar 2023 13:00:27 -0500 Subject: [PATCH 039/104] Add files via upload --- scripts/xyz_grid.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py index 8ede2aa0c..e90108171 100644 --- a/scripts/xyz_grid.py +++ b/scripts/xyz_grid.py @@ -25,6 +25,8 @@ from modules.ui_components import ToolButton fill_values_symbol = "\U0001f4d2" # 📒 +AxisInfo = namedtuple('AxisInfo', ['axis', 'values']) + def apply_field(field): def fun(p, x, xs): @@ -186,6 +188,7 @@ axis_options = [ AxisOption("Steps", int, apply_field("steps")), AxisOptionTxt2Img("Hires steps", int, apply_field("hr_second_pass_steps")), AxisOption("CFG Scale", float, apply_field("cfg_scale")), + AxisOptionImg2Img("Image CFG Scale", float, apply_field("image_cfg_scale")), AxisOption("Prompt S/R", str, apply_prompt, format_value=format_value), AxisOption("Prompt order", str_permutations, apply_order, format_value=format_value_join_list), AxisOptionTxt2Img("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]), @@ -368,7 +371,7 @@ class Script(scripts.Script): include_lone_images = gr.Checkbox(label='Include Sub Images', value=False, elem_id=self.elem_id("include_lone_images")) include_sub_grids = gr.Checkbox(label='Include Sub Grids', value=False, elem_id=self.elem_id("include_sub_grids")) with gr.Column(): - margin_size = gr.Slider(label="Grid margins (px)", min=0, max=500, value=0, step=2, elem_id=self.elem_id("margin_size")) + margin_size = gr.Slider(label="Grid margins (px)", minimum=0, maximum=500, value=0, step=2, elem_id=self.elem_id("margin_size")) with gr.Row(variant="compact", elem_id="swap_axes"): swap_xy_axes_button = gr.Button(value="Swap X/Y axes", elem_id="xy_grid_swap_axes_button") @@ -530,6 +533,10 @@ class Script(scripts.Script): grid_infotext = [None] + state.xyz_plot_x = AxisInfo(x_opt, xs) + state.xyz_plot_y = AxisInfo(y_opt, ys) + state.xyz_plot_z = AxisInfo(z_opt, zs) + # If one of the axes is very slow to change between (like SD model # checkpoint), then make sure it is in the outer iteration of the nested # `for` loop. From fe7d7dfd5ae9fdb09eea56af48c45ddc76fa3e28 Mon Sep 17 00:00:00 2001 From: DejitaruJin Date: Sat, 4 Mar 2023 15:40:35 -0500 Subject: [PATCH 040/104] Add files via upload --- scripts/xyz_grid.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py index e90108171..1cce87e12 100644 --- a/scripts/xyz_grid.py +++ b/scripts/xyz_grid.py @@ -236,7 +236,7 @@ def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend processed_result.all_prompts = [None] * list_size processed_result.all_seeds = [None] * list_size processed_result.infotexts = [None] * list_size - processed_result.index_of_first_image = 0 + processed_result.index_of_first_image = 1 idx = index(ix, iy, iz) if processed.images: @@ -312,8 +312,9 @@ def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend if draw_legend: z_grid = images.draw_grid_annotations(z_grid, sub_grid_size[0], sub_grid_size[1], title_texts, [[images.GridAnnotation()]]) processed_result.images.insert(0, z_grid) - processed_result.all_prompts.insert(0, processed_result.all_prompts[0]) - processed_result.all_seeds.insert(0, processed_result.all_seeds[0]) + #TODO: Deeper aspects of the program rely on index 0 "grid" images only having partial information, which is not ideal. + #processed_result.all_prompts.insert(0, processed_result.all_prompts[0]) + #processed_result.all_seeds.insert(0, processed_result.all_seeds[0]) processed_result.infotexts.insert(0, processed_result.infotexts[0]) return processed_result @@ -637,4 +638,6 @@ class Script(scripts.Script): del processed.all_seeds[1] del processed.infotexts[1] + print(processed.images) + return processed From eb29ff211af885a96cee3a97beb99194a6b22a3d Mon Sep 17 00:00:00 2001 From: DejitaruJin Date: Sat, 4 Mar 2023 16:06:40 -0500 Subject: [PATCH 041/104] Add files via upload --- scripts/xyz_grid.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py index 1cce87e12..7ed8a9da2 100644 --- a/scripts/xyz_grid.py +++ b/scripts/xyz_grid.py @@ -312,7 +312,7 @@ def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend if draw_legend: z_grid = images.draw_grid_annotations(z_grid, sub_grid_size[0], sub_grid_size[1], title_texts, [[images.GridAnnotation()]]) processed_result.images.insert(0, z_grid) - #TODO: Deeper aspects of the program rely on index 0 "grid" images only having partial information, which is not ideal. + #TODO: Deeper aspects of the program rely on grid info being misaligned between metadata arrays, which is not ideal. #processed_result.all_prompts.insert(0, processed_result.all_prompts[0]) #processed_result.all_seeds.insert(0, processed_result.all_seeds[0]) processed_result.infotexts.insert(0, processed_result.infotexts[0]) @@ -628,7 +628,9 @@ class Script(scripts.Script): # Auto-save main and sub-grids: grid_count = z_count + 1 if z_count > 1 else 1 for g in range(grid_count): - images.save_image(processed.images[g], p.outpath_grids, "xyz_grid", info=processed.infotexts[g], extension=opts.grid_format, prompt=processed.all_prompts[g], seed=processed.all_seeds[g], grid=True, p=processed) + #TODO: See previous comment about intentional data misalignment. + adj_g = g-1 if g > 0 else g + images.save_image(processed.images[g], p.outpath_grids, "xyz_grid", info=processed.infotexts[g], extension=opts.grid_format, prompt=processed.all_prompts[adj_g], seed=processed.all_seeds[adj_g], grid=True, p=processed) if not include_sub_grids: # Done with sub-grids, drop all related information: @@ -638,6 +640,4 @@ class Script(scripts.Script): del processed.all_seeds[1] del processed.infotexts[1] - print(processed.images) - return processed From b012d70f15641d6b85c9257b83cec892e941609c Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Sat, 4 Mar 2023 17:51:37 -0500 Subject: [PATCH 042/104] update using original defaults --- modules/api/api.py | 17 +++++++++++------ modules/api/models.py | 6 ++++-- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index 7da9081b9..a6bb439c7 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -180,8 +180,8 @@ class Api: populate = txt2imgreq.copy(update={ # Override __init__ params "sampler_name": validate_sampler_name(txt2imgreq.sampler_name or txt2imgreq.sampler_index), - "do_not_save_samples": True if not 'do_not_save_samples' in vars(txt2imgreq) else txt2imgreq.do_not_save_samples, - "do_not_save_grid": True if not 'do_not_save_grid' in vars(txt2imgreq) else txt2imgreq.do_not_save_grid, + "do_not_save_samples": txt2imgreq.do_not_save, + "do_not_save_grid": txt2imgreq.do_not_save, } ) if populate.sampler_name: @@ -190,8 +190,9 @@ class Api: args = vars(populate) args.pop('script_name', None) - send_images = True if not 'do_not_send_images' in args else not args['do_not_send_images'] - args.pop('do_not_send_images', None) + send_images = True if not 'do_not_send' in args else not args['do_not_send'] + args.pop('do_not_send', None) + args.pop('do_not_save', None) with self.queue_lock: p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args) @@ -223,8 +224,8 @@ class Api: populate = img2imgreq.copy(update={ # Override __init__ params "sampler_name": validate_sampler_name(img2imgreq.sampler_name or img2imgreq.sampler_index), - "do_not_save_samples": True if not 'do_not_save_samples' in img2imgreq else img2imgreq.do_not_save_samples, - "do_not_save_grid": True if not 'do_not_save_grid' in img2imgreq else img2imgreq.do_not_save_grid, + "do_not_save_samples": img2imgreq.do_not_save, + "do_not_save_grid": img2imgreq.do_not_save, "mask": mask } ) @@ -235,6 +236,10 @@ class Api: args.pop('include_init_images', None) # this is meant to be done by "exclude": True in model, but it's for a reason that I cannot determine. args.pop('script_name', None) + send_images = True if not 'do_not_send' in args else not args['do_not_send'] + args.pop('do_not_send', None) + args.pop('do_not_save', None) + send_images = True if not 'do_not_send_images' in args else not args['do_not_send_images'] args.pop('do_not_send_images', None) diff --git a/modules/api/models.py b/modules/api/models.py index aa4ea5d5f..2b66e1f03 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -104,7 +104,8 @@ StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator( {"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "script_name", "type": str, "default": None}, {"key": "script_args", "type": list, "default": []}, - {"key": "do_not_send_images", "type": bool, "default": False} + {"key": "do_not_send", "type": bool, "default": False}, + {"key": "do_not_save", "type": bool, "default": True} ] ).generate_model() @@ -119,7 +120,8 @@ StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator( {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}, {"key": "script_name", "type": str, "default": None}, {"key": "script_args", "type": list, "default": []}, - {"key": "do_not_send_images", "type": bool, "default": False} + {"key": "do_not_send", "type": bool, "default": False}, + {"key": "do_not_save", "type": bool, "default": True} ] ).generate_model() From c8b52c79755618736aec40a80d72043967274a59 Mon Sep 17 00:00:00 2001 From: DejitaruJin Date: Sat, 4 Mar 2023 19:32:09 -0500 Subject: [PATCH 043/104] Short-circuit error handling --- scripts/xyz_grid.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py index 7ed8a9da2..f79c46f6e 100644 --- a/scripts/xyz_grid.py +++ b/scripts/xyz_grid.py @@ -618,13 +618,17 @@ class Script(scripts.Script): margin_size=margin_size ) + if not processed.images: + # It broke, no further handling needed. + return processed + z_count = len(zs) if not include_lone_images: # Don't need sub-images anymore, drop from list: processed.images = processed.images[:z_count+1] - if opts.grid_save and processed.images: + if opts.grid_save: # Auto-save main and sub-grids: grid_count = z_count + 1 if z_count > 1 else 1 for g in range(grid_count): From d118cb6ea3f1a410b5e030519dc021eafc1d6b52 Mon Sep 17 00:00:00 2001 From: Brad Smith Date: Mon, 6 Mar 2023 13:18:35 -0500 Subject: [PATCH 044/104] use lowercase name for sorting; keep `UpscalerLanczos` and `UpscalerNearest` at the start of the list with `UpscalerNone` Co-authored-by: catboxanon <122327233+catboxanon@users.noreply.github.com> --- modules/modelloader.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/modelloader.py b/modules/modelloader.py index a7ac338c2..e351d808a 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -6,7 +6,7 @@ from urllib.parse import urlparse from basicsr.utils.download_util import load_file_from_url from modules import shared -from modules.upscaler import Upscaler, UpscalerNone +from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone from modules.paths import script_path, models_path @@ -172,5 +172,5 @@ def load_upscalers(): shared.sd_upscalers = sorted( datas, # Special case for UpscalerNone keeps it at the beginning of the list. - key=lambda x: x.name if not isinstance(x.scaler, UpscalerNone) else "" + key=lambda x: x.name.lower() if not isinstance(x.scaler, (UpscalerNone, UpscalerLanczos, UpscalerNearest)) else "" ) From 49b1dc5e07825e76c85ac4ac078fd63aa835e8bd Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Mon, 6 Mar 2023 21:00:34 +0200 Subject: [PATCH 045/104] Deduplicate extra network preview-search code --- extensions-builtin/Lora/ui_extra_networks_lora.py | 10 +--------- modules/ui_extra_networks.py | 10 ++++++++++ modules/ui_extra_networks_checkpoints.py | 11 +---------- modules/ui_extra_networks_hypernets.py | 9 +-------- modules/ui_extra_networks_textual_inversion.py | 8 +------- 5 files changed, 14 insertions(+), 34 deletions(-) diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py index 22cabcb0f..4c1549d73 100644 --- a/extensions-builtin/Lora/ui_extra_networks_lora.py +++ b/extensions-builtin/Lora/ui_extra_networks_lora.py @@ -15,18 +15,10 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): def list_items(self): for name, lora_on_disk in lora.available_loras.items(): path, ext = os.path.splitext(lora_on_disk.filename) - previews = [path + ".png", path + ".preview.png"] - - preview = None - for file in previews: - if os.path.isfile(file): - preview = self.link_preview(file) - break - yield { "name": name, "filename": path, - "preview": preview, + "preview": self._find_preview(path), "search_term": self.search_terms_from_path(lora_on_disk.filename), "prompt": json.dumps(f""), "local_preview": path + ".png", diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 71f1d81f2..1a10a5df7 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -2,6 +2,7 @@ import glob import os.path import urllib.parse from pathlib import Path +from typing import Optional from modules import shared import gradio as gr @@ -137,6 +138,15 @@ class ExtraNetworksPage: return self.card_page.format(**args) + def _find_preview(self, path: str) -> Optional[str]: + """ + Find a preview PNG for a given path (without extension) and call link_preview on it. + """ + for file in [path + ".png", path + ".preview.png"]: + if os.path.isfile(file): + return self.link_preview(file) + return None + def intialize(): extra_pages.clear() diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py index 04097a794..b712d12b7 100644 --- a/modules/ui_extra_networks_checkpoints.py +++ b/modules/ui_extra_networks_checkpoints.py @@ -1,7 +1,6 @@ import html import json import os -import urllib.parse from modules import shared, ui_extra_networks, sd_models @@ -17,18 +16,10 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage): checkpoint: sd_models.CheckpointInfo for name, checkpoint in sd_models.checkpoints_list.items(): path, ext = os.path.splitext(checkpoint.filename) - previews = [path + ".png", path + ".preview.png"] - - preview = None - for file in previews: - if os.path.isfile(file): - preview = self.link_preview(file) - break - yield { "name": checkpoint.name_for_extra, "filename": path, - "preview": preview, + "preview": self._find_preview(path), "search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""), "onclick": '"' + html.escape(f"""return selectCheckpoint({json.dumps(name)})""") + '"', "local_preview": path + ".png", diff --git a/modules/ui_extra_networks_hypernets.py b/modules/ui_extra_networks_hypernets.py index 578510887..89f33242f 100644 --- a/modules/ui_extra_networks_hypernets.py +++ b/modules/ui_extra_networks_hypernets.py @@ -14,18 +14,11 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage): def list_items(self): for name, path in shared.hypernetworks.items(): path, ext = os.path.splitext(path) - previews = [path + ".png", path + ".preview.png"] - - preview = None - for file in previews: - if os.path.isfile(file): - preview = self.link_preview(file) - break yield { "name": name, "filename": path, - "preview": preview, + "preview": self._find_preview(path), "search_term": self.search_terms_from_path(path), "prompt": json.dumps(f""), "local_preview": path + ".png", diff --git a/modules/ui_extra_networks_textual_inversion.py b/modules/ui_extra_networks_textual_inversion.py index bb64eb81e..f7057390d 100644 --- a/modules/ui_extra_networks_textual_inversion.py +++ b/modules/ui_extra_networks_textual_inversion.py @@ -15,16 +15,10 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage): def list_items(self): for embedding in sd_hijack.model_hijack.embedding_db.word_embeddings.values(): path, ext = os.path.splitext(embedding.filename) - preview_file = path + ".preview.png" - - preview = None - if os.path.isfile(preview_file): - preview = self.link_preview(preview_file) - yield { "name": embedding.name, "filename": embedding.filename, - "preview": preview, + "preview": self._find_preview(path), "search_term": self.search_terms_from_path(embedding.filename), "prompt": json.dumps(embedding.name), "local_preview": path + ".preview.png", From 06f167da37cd00ea8241bd2a6a3c12d8c5fb9eaf Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Mon, 6 Mar 2023 21:14:06 +0200 Subject: [PATCH 046/104] Extra networks: support .txt description sidecar file --- extensions-builtin/Lora/ui_extra_networks_lora.py | 1 + html/extra-networks-card.html | 1 + modules/ui_extra_networks.py | 15 +++++++++++++++ modules/ui_extra_networks_checkpoints.py | 1 + modules/ui_extra_networks_hypernets.py | 1 + modules/ui_extra_networks_textual_inversion.py | 1 + style.css | 11 +++++++++++ 7 files changed, 31 insertions(+) diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py index 4c1549d73..9da13a090 100644 --- a/extensions-builtin/Lora/ui_extra_networks_lora.py +++ b/extensions-builtin/Lora/ui_extra_networks_lora.py @@ -19,6 +19,7 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): "name": name, "filename": path, "preview": self._find_preview(path), + "description": self._find_description(path), "search_term": self.search_terms_from_path(lora_on_disk.filename), "prompt": json.dumps(f""), "local_preview": path + ".png", diff --git a/html/extra-networks-card.html b/html/extra-networks-card.html index 8a5e2fbd2..8612396d2 100644 --- a/html/extra-networks-card.html +++ b/html/extra-networks-card.html @@ -7,6 +7,7 @@ {name} + {description} diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 1a10a5df7..cd61a5694 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -1,6 +1,7 @@ import glob import os.path import urllib.parse +from functools import lru_cache from pathlib import Path from typing import Optional @@ -131,6 +132,7 @@ class ExtraNetworksPage: "tabname": json.dumps(tabname), "local_preview": json.dumps(item["local_preview"]), "name": item["name"], + "description": (item.get("description") or ""), "card_clicked": onclick, "save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {json.dumps(tabname)}, {json.dumps(item["local_preview"])})""") + '"', "search_term": item.get("search_term", ""), @@ -147,6 +149,19 @@ class ExtraNetworksPage: return self.link_preview(file) return None + @lru_cache(maxsize=512) + def _find_description(self, path: str) -> Optional[str]: + """ + Find and read a description file for a given path (without extension). + """ + for file in [f"{path}.txt", f"{path}.description.txt"]: + try: + with open(file, "r", encoding="utf-8", errors="replace") as f: + return f.read() + except OSError: + pass + return None + def intialize(): extra_pages.clear() diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py index b712d12b7..1deb785aa 100644 --- a/modules/ui_extra_networks_checkpoints.py +++ b/modules/ui_extra_networks_checkpoints.py @@ -20,6 +20,7 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage): "name": checkpoint.name_for_extra, "filename": path, "preview": self._find_preview(path), + "description": self._find_description(path), "search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""), "onclick": '"' + html.escape(f"""return selectCheckpoint({json.dumps(name)})""") + '"', "local_preview": path + ".png", diff --git a/modules/ui_extra_networks_hypernets.py b/modules/ui_extra_networks_hypernets.py index 89f33242f..80cc2a248 100644 --- a/modules/ui_extra_networks_hypernets.py +++ b/modules/ui_extra_networks_hypernets.py @@ -19,6 +19,7 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage): "name": name, "filename": path, "preview": self._find_preview(path), + "description": self._find_description(path), "search_term": self.search_terms_from_path(path), "prompt": json.dumps(f""), "local_preview": path + ".png", diff --git a/modules/ui_extra_networks_textual_inversion.py b/modules/ui_extra_networks_textual_inversion.py index f7057390d..f3bae6665 100644 --- a/modules/ui_extra_networks_textual_inversion.py +++ b/modules/ui_extra_networks_textual_inversion.py @@ -19,6 +19,7 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage): "name": embedding.name, "filename": embedding.filename, "preview": self._find_preview(path), + "description": self._find_description(path), "search_term": self.search_terms_from_path(embedding.filename), "prompt": json.dumps(embedding.name), "local_preview": path + ".preview.png", diff --git a/style.css b/style.css index 05572f662..9f2af2638 100644 --- a/style.css +++ b/style.css @@ -939,6 +939,17 @@ footer { line-break: anywhere; } +.extra-network-cards .card .actions .description { + display: block; + max-height: 3em; + white-space: pre-wrap; + line-height: 1.1; +} + +.extra-network-cards .card .actions .description:hover { + max-height: none; +} + .extra-network-cards .card .actions:hover .additional{ display: block; } From fec0a895119a124a295e3dad5205de5766031dc7 Mon Sep 17 00:00:00 2001 From: Pam Date: Tue, 7 Mar 2023 00:33:13 +0500 Subject: [PATCH 047/104] scaled dot product attention --- html/licenses.html | 219 +++++++++++++++++++++++++++++ modules/sd_hijack.py | 4 + modules/sd_hijack_optimizations.py | 42 ++++++ modules/shared.py | 1 + 4 files changed, 266 insertions(+) diff --git a/html/licenses.html b/html/licenses.html index 570630eb4..bddbf4665 100644 --- a/html/licenses.html +++ b/html/licenses.html @@ -417,3 +417,222 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +

Scaled Dot Product Attention

+Some small amounts of code borrowed and reworked. +
+   Copyright 2023 The HuggingFace Team. All rights reserved.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
+
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+
+      "Licensor" shall mean the copyright owner or entity authorized by
+      the copyright owner that is granting the License.
+
+      "Legal Entity" shall mean the union of the acting entity and all
+      other entities that control, are controlled by, or are under common
+      control with that entity. For the purposes of this definition,
+      "control" means (i) the power, direct or indirect, to cause the
+      direction or management of such entity, whether by contract or
+      otherwise, or (ii) ownership of fifty percent (50%) or more of the
+      outstanding shares, or (iii) beneficial ownership of such entity.
+
+      "You" (or "Your") shall mean an individual or Legal Entity
+      exercising permissions granted by this License.
+
+      "Source" form shall mean the preferred form for making modifications,
+      including but not limited to software source code, documentation
+      source, and configuration files.
+
+      "Object" form shall mean any form resulting from mechanical
+      transformation or translation of a Source form, including but
+      not limited to compiled object code, generated documentation,
+      and conversions to other media types.
+
+      "Work" shall mean the work of authorship, whether in Source or
+      Object form, made available under the License, as indicated by a
+      copyright notice that is included in or attached to the work
+      (an example is provided in the Appendix below).
+
+      "Derivative Works" shall mean any work, whether in Source or Object
+      form, that is based on (or derived from) the Work and for which the
+      editorial revisions, annotations, elaborations, or other modifications
+      represent, as a whole, an original work of authorship. For the purposes
+      of this License, Derivative Works shall not include works that remain
+      separable from, or merely link (or bind by name) to the interfaces of,
+      the Work and Derivative Works thereof.
+
+      "Contribution" shall mean any work of authorship, including
+      the original version of the Work and any modifications or additions
+      to that Work or Derivative Works thereof, that is intentionally
+      submitted to Licensor for inclusion in the Work by the copyright owner
+      or by an individual or Legal Entity authorized to submit on behalf of
+      the copyright owner. For the purposes of this definition, "submitted"
+      means any form of electronic, verbal, or written communication sent
+      to the Licensor or its representatives, including but not limited to
+      communication on electronic mailing lists, source code control systems,
+      and issue tracking systems that are managed by, or on behalf of, the
+      Licensor for the purpose of discussing and improving the Work, but
+      excluding communication that is conspicuously marked or otherwise
+      designated in writing by the copyright owner as "Not a Contribution."
+
+      "Contributor" shall mean Licensor and any individual or Legal Entity
+      on behalf of whom a Contribution has been received by Licensor and
+      subsequently incorporated within the Work.
+
+   2. Grant of Copyright License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      copyright license to reproduce, prepare Derivative Works of,
+      publicly display, publicly perform, sublicense, and distribute the
+      Work and such Derivative Works in Source or Object form.
+
+   3. Grant of Patent License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      (except as stated in this section) patent license to make, have made,
+      use, offer to sell, sell, import, and otherwise transfer the Work,
+      where such license applies only to those patent claims licensable
+      by such Contributor that are necessarily infringed by their
+      Contribution(s) alone or by combination of their Contribution(s)
+      with the Work to which such Contribution(s) was submitted. If You
+      institute patent litigation against any entity (including a
+      cross-claim or counterclaim in a lawsuit) alleging that the Work
+      or a Contribution incorporated within the Work constitutes direct
+      or contributory patent infringement, then any patent licenses
+      granted to You under this License for that Work shall terminate
+      as of the date such litigation is filed.
+
+   4. Redistribution. You may reproduce and distribute copies of the
+      Work or Derivative Works thereof in any medium, with or without
+      modifications, and in Source or Object form, provided that You
+      meet the following conditions:
+
+      (a) You must give any other recipients of the Work or
+          Derivative Works a copy of this License; and
+
+      (b) You must cause any modified files to carry prominent notices
+          stating that You changed the files; and
+
+      (c) You must retain, in the Source form of any Derivative Works
+          that You distribute, all copyright, patent, trademark, and
+          attribution notices from the Source form of the Work,
+          excluding those notices that do not pertain to any part of
+          the Derivative Works; and
+
+      (d) If the Work includes a "NOTICE" text file as part of its
+          distribution, then any Derivative Works that You distribute must
+          include a readable copy of the attribution notices contained
+          within such NOTICE file, excluding those notices that do not
+          pertain to any part of the Derivative Works, in at least one
+          of the following places: within a NOTICE text file distributed
+          as part of the Derivative Works; within the Source form or
+          documentation, if provided along with the Derivative Works; or,
+          within a display generated by the Derivative Works, if and
+          wherever such third-party notices normally appear. The contents
+          of the NOTICE file are for informational purposes only and
+          do not modify the License. You may add Your own attribution
+          notices within Derivative Works that You distribute, alongside
+          or as an addendum to the NOTICE text from the Work, provided
+          that such additional attribution notices cannot be construed
+          as modifying the License.
+
+      You may add Your own copyright statement to Your modifications and
+      may provide additional or different license terms and conditions
+      for use, reproduction, or distribution of Your modifications, or
+      for any such Derivative Works as a whole, provided Your use,
+      reproduction, and distribution of the Work otherwise complies with
+      the conditions stated in this License.
+
+   5. Submission of Contributions. Unless You explicitly state otherwise,
+      any Contribution intentionally submitted for inclusion in the Work
+      by You to the Licensor shall be under the terms and conditions of
+      this License, without any additional terms or conditions.
+      Notwithstanding the above, nothing herein shall supersede or modify
+      the terms of any separate license agreement you may have executed
+      with Licensor regarding such Contributions.
+
+   6. Trademarks. This License does not grant permission to use the trade
+      names, trademarks, service marks, or product names of the Licensor,
+      except as required for reasonable and customary use in describing the
+      origin of the Work and reproducing the content of the NOTICE file.
+
+   7. Disclaimer of Warranty. Unless required by applicable law or
+      agreed to in writing, Licensor provides the Work (and each
+      Contributor provides its Contributions) on an "AS IS" BASIS,
+      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+      implied, including, without limitation, any warranties or conditions
+      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+      PARTICULAR PURPOSE. You are solely responsible for determining the
+      appropriateness of using or redistributing the Work and assume any
+      risks associated with Your exercise of permissions under this License.
+
+   8. Limitation of Liability. In no event and under no legal theory,
+      whether in tort (including negligence), contract, or otherwise,
+      unless required by applicable law (such as deliberate and grossly
+      negligent acts) or agreed to in writing, shall any Contributor be
+      liable to You for damages, including any direct, indirect, special,
+      incidental, or consequential damages of any character arising as a
+      result of this License or out of the use or inability to use the
+      Work (including but not limited to damages for loss of goodwill,
+      work stoppage, computer failure or malfunction, or any and all
+      other commercial damages or losses), even if such Contributor
+      has been advised of the possibility of such damages.
+
+   9. Accepting Warranty or Additional Liability. While redistributing
+      the Work or Derivative Works thereof, You may choose to offer,
+      and charge a fee for, acceptance of support, warranty, indemnity,
+      or other liability obligations and/or rights consistent with this
+      License. However, in accepting such obligations, You may act only
+      on Your own behalf and on Your sole responsibility, not on behalf
+      of any other Contributor, and only if You agree to indemnify,
+      defend, and hold each Contributor harmless for any liability
+      incurred by, or claims asserted against, such Contributor by reason
+      of your accepting any such warranty or additional liability.
+
+   END OF TERMS AND CONDITIONS
+
+   APPENDIX: How to apply the Apache License to your work.
+
+      To apply the Apache License to your work, attach the following
+      boilerplate notice, with the fields enclosed by brackets "[]"
+      replaced with your own identifying information. (Don't include
+      the brackets!)  The text should be enclosed in the appropriate
+      comment syntax for the file format. We also recommend that a
+      file or class name and description of purpose be included on the
+      same "printed page" as the copyright notice for easier
+      identification within third-party archives.
+
+   Copyright [yyyy] [name of copyright owner]
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
+
\ No newline at end of file diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 794767831..76cb91209 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -42,6 +42,10 @@ def apply_optimizations(): ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward optimization_method = 'xformers' + elif cmd_opts.opt_sdp_attention and (hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(getattr(torch.nn.functional, "scaled_dot_product_attention"))): + print("Applying scaled dot product cross attention optimization.") + ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.scaled_dot_product_attention_forward + optimization_method = 'sdp' elif cmd_opts.opt_sub_quad_attention: print("Applying sub-quadratic cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.sub_quad_attention_forward diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index c02d954c7..a324a5927 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -346,6 +346,48 @@ def xformers_attention_forward(self, x, context=None, mask=None): out = rearrange(out, 'b n h d -> b n (h d)', h=h) return self.to_out(out) +# Based on Diffusers usage of scaled dot product attention from https://github.com/huggingface/diffusers/blob/c7da8fd23359a22d0df2741688b5b4f33c26df21/src/diffusers/models/cross_attention.py +# The scaled_dot_product_attention_forward function contains parts of code under Apache-2.0 license listed under Scaled Dot Product Attention in the Licenses section of the web UI interface +def scaled_dot_product_attention_forward(self, x, context=None, mask=None): + batch_size, sequence_length, inner_dim = x.shape + + if mask is not None: + mask = self.prepare_attention_mask(mask, sequence_length, batch_size) + mask = mask.view(batch_size, self.heads, -1, mask.shape[-1]) + + h = self.heads + q_in = self.to_q(x) + context = default(context, x) + + context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context) + k_in = self.to_k(context_k) + v_in = self.to_v(context_v) + + head_dim = inner_dim // h + q = q_in.view(batch_size, -1, h, head_dim).transpose(1, 2) + k = k_in.view(batch_size, -1, h, head_dim).transpose(1, 2) + v = v_in.view(batch_size, -1, h, head_dim).transpose(1, 2) + + del q_in, k_in, v_in + + dtype = q.dtype + if shared.opts.upcast_attn: + q, k = q.float(), k.float() + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + hidden_states = torch.nn.functional.scaled_dot_product_attention( + q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, h * head_dim) + hidden_states = hidden_states.to(dtype) + + # linear proj + hidden_states = self.to_out[0](hidden_states) + # dropout + hidden_states = self.to_out[1](hidden_states) + return hidden_states + def cross_attention_attnblock_forward(self, x): h_ = x h_ = self.norm(h_) diff --git a/modules/shared.py b/modules/shared.py index 805f9cc19..12d0756bd 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -69,6 +69,7 @@ parser.add_argument("--sub-quad-kv-chunk-size", type=int, help="kv chunk size fo parser.add_argument("--sub-quad-chunk-threshold", type=int, help="the percentage of VRAM threshold for the sub-quadratic cross-attention layer optimization to use chunking", default=None) parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.") parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find") +parser.add_argument("--opt-sdp-attention", action='store_true', help="enable scaled dot product cross-attention layer optimization; requires PyTorch 2.*") parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization") parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI") parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower) From f85a192f998dcc17d0a59d8755c76100e8e31bde Mon Sep 17 00:00:00 2001 From: Yea Chen Date: Tue, 7 Mar 2023 04:04:35 +0800 Subject: [PATCH 048/104] Update modules/api/api.py Suggested change by @akx Co-authored-by: Aarni Koskela --- modules/api/api.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index 46cb7c811..12b383861 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -177,14 +177,8 @@ class Api: return script, script_idx def get_scripts_list(self): - t2ilist = [] - i2ilist = [] - - for a in scripts.scripts_txt2img.titles: - t2ilist.append(str(a.lower())) - - for b in scripts.scripts_img2img.titles: - i2ilist.append(str(b.lower())) + t2ilist = [str(title.lower()) for title in scripts.scripts_txt2img.titles] + i2ilist = [str(title.lower()) for title in scripts.scripts_img2img.titles] return ScriptsList(txt2img = t2ilist, img2img = i2ilist) From 09c73710c9145afdd22bcdb6da68db8e346e35b6 Mon Sep 17 00:00:00 2001 From: vladlearns Date: Wed, 8 Mar 2023 23:00:55 +0200 Subject: [PATCH 049/104] chore: auto update all extensions using scripts --- extensions/update-all.bat | 1 + extensions/update-all.sh | 3 +++ 2 files changed, 4 insertions(+) create mode 100644 extensions/update-all.bat create mode 100644 extensions/update-all.sh diff --git a/extensions/update-all.bat b/extensions/update-all.bat new file mode 100644 index 000000000..75a17cbe7 --- /dev/null +++ b/extensions/update-all.bat @@ -0,0 +1 @@ +for /d %%i in (*) do @if exist "%%i\.git" (echo Pulling updates for %%i... & git -C "%%i" pull) \ No newline at end of file diff --git a/extensions/update-all.sh b/extensions/update-all.sh new file mode 100644 index 000000000..b00de9278 --- /dev/null +++ b/extensions/update-all.sh @@ -0,0 +1,3 @@ +ls | while read dir; do if [ -d "$dir/.git" ]; +then echo "Pulling updates for $dir..."; +git -C "$dir" pull; fi; done \ No newline at end of file From b07b7057f0636aa142471ec27841a8001a85f98b Mon Sep 17 00:00:00 2001 From: vladlearns Date: Thu, 9 Mar 2023 16:25:18 +0200 Subject: [PATCH 050/104] chore: removed scripts and added a flag to launch.py --- extensions/update-all.bat | 1 - extensions/update-all.sh | 3 --- launch.py | 14 +++++++++++++- 3 files changed, 13 insertions(+), 5 deletions(-) delete mode 100644 extensions/update-all.bat delete mode 100644 extensions/update-all.sh diff --git a/extensions/update-all.bat b/extensions/update-all.bat deleted file mode 100644 index 75a17cbe7..000000000 --- a/extensions/update-all.bat +++ /dev/null @@ -1 +0,0 @@ -for /d %%i in (*) do @if exist "%%i\.git" (echo Pulling updates for %%i... & git -C "%%i" pull) \ No newline at end of file diff --git a/extensions/update-all.sh b/extensions/update-all.sh deleted file mode 100644 index b00de9278..000000000 --- a/extensions/update-all.sh +++ /dev/null @@ -1,3 +0,0 @@ -ls | while read dir; do if [ -d "$dir/.git" ]; -then echo "Pulling updates for $dir..."; -git -C "$dir" pull; fi; done \ No newline at end of file diff --git a/launch.py b/launch.py index a68bb3a91..ba3067915 100644 --- a/launch.py +++ b/launch.py @@ -161,7 +161,15 @@ def git_clone(url, dir, name, commithash=None): if commithash is not None: run(f'"{git}" -C "{dir}" checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}") - +def git_pull_recursive(dir): + for subdir, _, _ in os.walk(dir): + if os.path.exists(os.path.join(subdir, '.git')): + try: + output = subprocess.check_output(['git', '-C', subdir, 'pull']) + print(f"Pulled changes for repository in '{subdir}':\n{output.decode('utf-8').strip()}\n") + except subprocess.CalledProcessError as e: + print(f"Couldn't perform 'git pull' on repository in '{subdir}':\n{e.output.decode('utf-8').strip()}\n") + def version_check(commit): try: import requests @@ -247,6 +255,7 @@ def prepare_environment(): args, _ = parser.parse_known_args(sys.argv) sys.argv, _ = extract_arg(sys.argv, '-f') + sys.argv, update_all_extensions = extract_arg(sys.argv, '--update-all-extensions') sys.argv, skip_torch_cuda_test = extract_arg(sys.argv, '--skip-torch-cuda-test') sys.argv, skip_python_version_check = extract_arg(sys.argv, '--skip-python-version-check') sys.argv, reinstall_xformers = extract_arg(sys.argv, '--reinstall-xformers') @@ -312,6 +321,9 @@ def prepare_environment(): if update_check: version_check(commit) + + if update_all_extensions: + git_pull_recursive(dir_extensions) if "--exit" in sys.argv: print("Exiting because of --exit argument") From 13081dd45ece33457f6cb2cad3a8e7840a0a6eaf Mon Sep 17 00:00:00 2001 From: vladlearns Date: Thu, 9 Mar 2023 16:56:06 +0200 Subject: [PATCH 051/104] chore: added autostash flag to pull --- launch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/launch.py b/launch.py index ba3067915..8cbf1ca5a 100644 --- a/launch.py +++ b/launch.py @@ -165,7 +165,7 @@ def git_pull_recursive(dir): for subdir, _, _ in os.walk(dir): if os.path.exists(os.path.join(subdir, '.git')): try: - output = subprocess.check_output(['git', '-C', subdir, 'pull']) + output = subprocess.check_output(['git', '-C', subdir, 'pull', '--autostash']) print(f"Pulled changes for repository in '{subdir}':\n{output.decode('utf-8').strip()}\n") except subprocess.CalledProcessError as e: print(f"Couldn't perform 'git pull' on repository in '{subdir}':\n{e.output.decode('utf-8').strip()}\n") From 37acba263389e22bc46cfffc80b2ca8b76a85287 Mon Sep 17 00:00:00 2001 From: Pam Date: Fri, 10 Mar 2023 12:19:36 +0500 Subject: [PATCH 052/104] argument to disable memory efficient for sdp --- modules/sd_hijack.py | 11 ++++++++--- modules/sd_hijack_optimizations.py | 4 ++++ modules/shared.py | 1 + 3 files changed, 13 insertions(+), 3 deletions(-) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 76cb91209..f62e9adb1 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -43,9 +43,14 @@ def apply_optimizations(): ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward optimization_method = 'xformers' elif cmd_opts.opt_sdp_attention and (hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(getattr(torch.nn.functional, "scaled_dot_product_attention"))): - print("Applying scaled dot product cross attention optimization.") - ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.scaled_dot_product_attention_forward - optimization_method = 'sdp' + if cmd_opts.opt_sdp_no_mem_attention: + print("Applying scaled dot product cross attention optimization (without memory efficient attention).") + ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.scaled_dot_product_no_mem_attention_forward + optimization_method = 'sdp-no-mem' + else: + print("Applying scaled dot product cross attention optimization.") + ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.scaled_dot_product_attention_forward + optimization_method = 'sdp' elif cmd_opts.opt_sub_quad_attention: print("Applying sub-quadratic cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.sub_quad_attention_forward diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index a324a5927..68b1dd84f 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -388,6 +388,10 @@ def scaled_dot_product_attention_forward(self, x, context=None, mask=None): hidden_states = self.to_out[1](hidden_states) return hidden_states +def scaled_dot_product_no_mem_attention_forward(self, x, context=None, mask=None): + with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False): + return scaled_dot_product_attention_forward(self, x, context, mask) + def cross_attention_attnblock_forward(self, x): h_ = x h_ = self.norm(h_) diff --git a/modules/shared.py b/modules/shared.py index 12d0756bd..4b81c591d 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -70,6 +70,7 @@ parser.add_argument("--sub-quad-chunk-threshold", type=int, help="the percentage parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.") parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find") parser.add_argument("--opt-sdp-attention", action='store_true', help="enable scaled dot product cross-attention layer optimization; requires PyTorch 2.*") +parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="disables memory efficient sdp, makes image generation deterministic; requires --opt-sdp-attention") parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization") parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI") parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower) From 0981dea94832f34d638b1aa8964cfaeffd223b47 Mon Sep 17 00:00:00 2001 From: Pam Date: Fri, 10 Mar 2023 12:58:10 +0500 Subject: [PATCH 053/104] sdp refactoring --- modules/sd_hijack.py | 19 ++++++++++--------- modules/shared.py | 2 +- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index f62e9adb1..e98ae51a1 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -37,20 +37,21 @@ def apply_optimizations(): optimization_method = None + can_use_sdp = hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(getattr(torch.nn.functional, "scaled_dot_product_attention")) # not everyone has torch 2.x to use sdp + if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)): print("Applying xformers cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward optimization_method = 'xformers' - elif cmd_opts.opt_sdp_attention and (hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(getattr(torch.nn.functional, "scaled_dot_product_attention"))): - if cmd_opts.opt_sdp_no_mem_attention: - print("Applying scaled dot product cross attention optimization (without memory efficient attention).") - ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.scaled_dot_product_no_mem_attention_forward - optimization_method = 'sdp-no-mem' - else: - print("Applying scaled dot product cross attention optimization.") - ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.scaled_dot_product_attention_forward - optimization_method = 'sdp' + elif cmd_opts.opt_sdp_no_mem_attention and can_use_sdp: + print("Applying scaled dot product cross attention optimization (without memory efficient attention).") + ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.scaled_dot_product_no_mem_attention_forward + optimization_method = 'sdp-no-mem' + elif cmd_opts.opt_sdp_attention and can_use_sdp: + print("Applying scaled dot product cross attention optimization.") + ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.scaled_dot_product_attention_forward + optimization_method = 'sdp' elif cmd_opts.opt_sub_quad_attention: print("Applying sub-quadratic cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.sub_quad_attention_forward diff --git a/modules/shared.py b/modules/shared.py index 4b81c591d..66a6bfa55 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -70,7 +70,7 @@ parser.add_argument("--sub-quad-chunk-threshold", type=int, help="the percentage parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.") parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find") parser.add_argument("--opt-sdp-attention", action='store_true', help="enable scaled dot product cross-attention layer optimization; requires PyTorch 2.*") -parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="disables memory efficient sdp, makes image generation deterministic; requires --opt-sdp-attention") +parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="enable scaled dot product cross-attention layer optimization without memory efficient attention, makes image generation deterministic; requires PyTorch 2.*") parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization") parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI") parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower) From 1226028b9c1b153b6ceef62d8678ecb84c9d4fcd Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Fri, 10 Mar 2023 11:21:48 -0500 Subject: [PATCH 054/104] fix silly math error --- scripts/xyz_grid.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py index 1ba954ac6..8c816a736 100644 --- a/scripts/xyz_grid.py +++ b/scripts/xyz_grid.py @@ -485,7 +485,7 @@ class Script(scripts.Script): zs = process_axis(z_opt, z_values) # this could be moved to common code, but unlikely to be ever triggered anywhere else - Image.MAX_IMAGE_PIXELS = opts.img_max_size_mp * 1.1 # allow 10% overhead for margins and legend + Image.MAX_IMAGE_PIXELS = opts.img_max_size_mp * 1000000 * 1.1 # allow 10% overhead for margins and legend grid_mp = round(len(xs) * len(ys) * len(zs) * p.width * p.height / 1000000) if grid_mp > opts.img_max_size_mp: return Processed(p, [], p.seed, info=f'Error: Resulting grid would be too large ({grid_mp} MPixels) (max configured size is {opts.img_max_size_mp} MPixels)') From 8d7fa2f67cb0554d8902d5d407166876020e067e Mon Sep 17 00:00:00 2001 From: Pam Date: Fri, 10 Mar 2023 22:48:41 +0500 Subject: [PATCH 055/104] sdp_attnblock_forward hijack --- modules/sd_hijack.py | 2 ++ modules/sd_hijack_optimizations.py | 24 ++++++++++++++++++++++++ 2 files changed, 26 insertions(+) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index e98ae51a1..f4bb0266f 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -47,10 +47,12 @@ def apply_optimizations(): elif cmd_opts.opt_sdp_no_mem_attention and can_use_sdp: print("Applying scaled dot product cross attention optimization (without memory efficient attention).") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.scaled_dot_product_no_mem_attention_forward + ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sdp_no_mem_attnblock_forward optimization_method = 'sdp-no-mem' elif cmd_opts.opt_sdp_attention and can_use_sdp: print("Applying scaled dot product cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.scaled_dot_product_attention_forward + ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sdp_attnblock_forward optimization_method = 'sdp' elif cmd_opts.opt_sub_quad_attention: print("Applying sub-quadratic cross attention optimization.") diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 68b1dd84f..2e307b5d0 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -473,6 +473,30 @@ def xformers_attnblock_forward(self, x): except NotImplementedError: return cross_attention_attnblock_forward(self, x) +def sdp_attnblock_forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + b, c, h, w = q.shape + q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v)) + dtype = q.dtype + if shared.opts.upcast_attn: + q, k = q.float(), k.float() + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + out = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False) + out = out.to(dtype) + out = rearrange(out, 'b (h w) c -> b c h w', h=h) + out = self.proj_out(out) + return x + out + +def sdp_no_mem_attnblock_forward(self, x): + with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False): + return sdp_attnblock_forward(self, x) + def sub_quad_attnblock_forward(self, x): h_ = x h_ = self.norm(h_) From 5fef67f6ee949a61826a3a043ea8610fd89fc371 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Fri, 10 Mar 2023 19:56:14 -0500 Subject: [PATCH 056/104] Requested changes --- modules/models/diffusion/uni_pc/sampler.py | 2 +- modules/sd_samplers_compvis.py | 4 +++- modules/shared.py | 3 +-- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/modules/models/diffusion/uni_pc/sampler.py b/modules/models/diffusion/uni_pc/sampler.py index 708a9b2ba..6bb3bb210 100644 --- a/modules/models/diffusion/uni_pc/sampler.py +++ b/modules/models/diffusion/uni_pc/sampler.py @@ -93,7 +93,7 @@ class UniPCSampler(object): guidance_scale=unconditional_guidance_scale, ) - uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=shared.opts.uni_pc_thresholding, variant=shared.opts.uni_pc_variant, condition=conditioning, unconditional_condition=unconditional_conditioning, before_sample=self.before_sample, after_sample=self.after_sample, after_update=self.after_update) + uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, variant=shared.opts.uni_pc_variant, condition=conditioning, unconditional_condition=unconditional_conditioning, before_sample=self.before_sample, after_sample=self.after_sample, after_update=self.after_update) x = uni_pc.sample(img, steps=S, skip_type=shared.opts.uni_pc_skip_type, method="multistep", order=shared.opts.uni_pc_order, lower_order_final=shared.opts.uni_pc_lower_order_final) return x.to(device), None diff --git a/modules/sd_samplers_compvis.py b/modules/sd_samplers_compvis.py index ad39ab2b3..7d07c4a5f 100644 --- a/modules/sd_samplers_compvis.py +++ b/modules/sd_samplers_compvis.py @@ -140,10 +140,12 @@ class VanillaStableDiffusionSampler: def adjust_steps_if_invalid(self, p, num_steps): if ((self.config.name == 'DDIM') and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS') or (self.config.name == 'UniPC'): + if self.config.name == 'UniPC' and num_steps < shared.opts.uni_pc_order: + num_steps = shared.opts.uni_pc_order valid_step = 999 / (1000 // num_steps) if valid_step == math.floor(valid_step): return int(valid_step) + 1 - + return num_steps def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): diff --git a/modules/shared.py b/modules/shared.py index 7c559fa42..29f8dccb1 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -485,10 +485,9 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters" 's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), 'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}), 'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma"), - 'uni_pc_variant': OptionInfo("bh1", "UniPC variant", gr.Radio, {"choices": ["bh1", "vary_coeff"]}), + 'uni_pc_variant': OptionInfo("bh1", "UniPC variant", gr.Radio, {"choices": ["bh1", "bh2", "vary_coeff"]}), 'uni_pc_skip_type': OptionInfo("time_uniform", "UniPC skip type", gr.Radio, {"choices": ["time_uniform", "time_quadratic", "logSNR"]}), 'uni_pc_order': OptionInfo(3, "UniPC order (must be < sampling steps)", gr.Slider, {"minimum": 1, "maximum": 50, "step": 1}), - 'uni_pc_thresholding': OptionInfo(False, "UniPC thresholding"), 'uni_pc_lower_order_final': OptionInfo(True, "UniPC lower order final"), })) From f261a4a53c153c630a506bc5282e9955c36b3ef2 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 11 Mar 2023 11:56:05 +0300 Subject: [PATCH 057/104] use selected device instead of always cuda for UniPC sampler --- modules/models/diffusion/uni_pc/sampler.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/modules/models/diffusion/uni_pc/sampler.py b/modules/models/diffusion/uni_pc/sampler.py index 6bb3bb210..bf346ff48 100644 --- a/modules/models/diffusion/uni_pc/sampler.py +++ b/modules/models/diffusion/uni_pc/sampler.py @@ -3,7 +3,8 @@ import torch from .uni_pc import NoiseScheduleVP, model_wrapper, UniPC -from modules import shared +from modules import shared, devices + class UniPCSampler(object): def __init__(self, model, **kwargs): @@ -16,8 +17,8 @@ class UniPCSampler(object): def register_buffer(self, name, attr): if type(attr) == torch.Tensor: - if attr.device != torch.device("cuda"): - attr = attr.to(torch.device("cuda")) + if attr.device != devices.device: + attr = attr.to(devices.device) setattr(self, name, attr) def set_hooks(self, before_sample, after_sample, after_update): From 58b5b7c2f1d3b843803c1fc7a0aae8b1d6be5763 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 11 Mar 2023 12:09:36 +0300 Subject: [PATCH 058/104] add UniPC options to infotext --- modules/generation_parameters_copypaste.py | 8 +++++++- modules/sd_samplers_compvis.py | 14 ++++++++++++++ modules/shared.py | 9 +++++++++ 3 files changed, 30 insertions(+), 1 deletion(-) diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 89dc23bff..cb3676553 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -288,6 +288,8 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model settings_map = {} + + infotext_to_setting_name_mapping = [ ('Clip skip', 'CLIP_stop_at_last_layers', ), ('Conditional mask weight', 'inpainting_mask_weight'), @@ -296,7 +298,11 @@ infotext_to_setting_name_mapping = [ ('Noise multiplier', 'initial_noise_multiplier'), ('Eta', 'eta_ancestral'), ('Eta DDIM', 'eta_ddim'), - ('Discard penultimate sigma', 'always_discard_next_to_last_sigma') + ('Discard penultimate sigma', 'always_discard_next_to_last_sigma'), + ('UniPC variant', 'uni_pc_variant'), + ('UniPC skip type', 'uni_pc_skip_type'), + ('UniPC order', 'uni_pc_order'), + ('UniPC lower order final', 'uni_pc_lower_order_final'), ] diff --git a/modules/sd_samplers_compvis.py b/modules/sd_samplers_compvis.py index 7d07c4a5f..083da18ca 100644 --- a/modules/sd_samplers_compvis.py +++ b/modules/sd_samplers_compvis.py @@ -129,6 +129,19 @@ class VanillaStableDiffusionSampler: if self.eta != 0.0: p.extra_generation_params["Eta DDIM"] = self.eta + if self.is_unipc: + keys = [ + ('UniPC variant', 'uni_pc_variant'), + ('UniPC skip type', 'uni_pc_skip_type'), + ('UniPC order', 'uni_pc_order'), + ('UniPC lower order final', 'uni_pc_lower_order_final'), + ] + + for name, key in keys: + v = getattr(shared.opts, key) + if v != shared.opts.get_default(key): + p.extra_generation_params[name] = v + for fieldname in ['p_sample_ddim', 'p_sample_plms']: if hasattr(self.sampler, fieldname): setattr(self.sampler, fieldname, self.p_sample_ddim_hook) @@ -138,6 +151,7 @@ class VanillaStableDiffusionSampler: self.mask = p.mask if hasattr(p, 'mask') else None self.nmask = p.nmask if hasattr(p, 'nmask') else None + def adjust_steps_if_invalid(self, p, num_steps): if ((self.config.name == 'DDIM') and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS') or (self.config.name == 'UniPC'): if self.config.name == 'UniPC' and num_steps < shared.opts.uni_pc_order: diff --git a/modules/shared.py b/modules/shared.py index 29f8dccb1..d481c25ba 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -563,6 +563,15 @@ class Options: return True + def get_default(self, key): + """returns the default value for the key""" + + data_label = self.data_labels.get(key) + if data_label is None: + return None + + return data_label.default + def save(self, filename): assert not cmd_opts.freeze_settings, "saving settings is disabled" From 1ace16e799c1ff43a6f67947be2506c2f83857a1 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 11 Mar 2023 12:21:53 +0300 Subject: [PATCH 059/104] use path to git from env variable for git_pull_recursive --- launch.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/launch.py b/launch.py index 8cbf1ca5a..0868f8a90 100644 --- a/launch.py +++ b/launch.py @@ -161,15 +161,17 @@ def git_clone(url, dir, name, commithash=None): if commithash is not None: run(f'"{git}" -C "{dir}" checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}") + def git_pull_recursive(dir): for subdir, _, _ in os.walk(dir): if os.path.exists(os.path.join(subdir, '.git')): try: - output = subprocess.check_output(['git', '-C', subdir, 'pull', '--autostash']) + output = subprocess.check_output([git, '-C', subdir, 'pull', '--autostash']) print(f"Pulled changes for repository in '{subdir}':\n{output.decode('utf-8').strip()}\n") except subprocess.CalledProcessError as e: print(f"Couldn't perform 'git pull' on repository in '{subdir}':\n{e.output.decode('utf-8').strip()}\n") + def version_check(commit): try: import requests From 3531a50080e63197752dd4d9b49f0ac34a758e12 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 11 Mar 2023 13:22:59 +0300 Subject: [PATCH 060/104] rename fields for API for saving/sending images save images to correct directories --- modules/api/api.py | 41 +++++++++++++++++------------------------ modules/api/models.py | 8 ++++---- modules/images.py | 3 --- 3 files changed, 21 insertions(+), 31 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index a6bb439c7..fbd50552b 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -178,29 +178,27 @@ class Api: def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): script, script_idx = self.get_script(txt2imgreq.script_name, scripts.scripts_txt2img) - populate = txt2imgreq.copy(update={ # Override __init__ params + populate = txt2imgreq.copy(update={ # Override __init__ params "sampler_name": validate_sampler_name(txt2imgreq.sampler_name or txt2imgreq.sampler_index), - "do_not_save_samples": txt2imgreq.do_not_save, - "do_not_save_grid": txt2imgreq.do_not_save, - } - ) + "do_not_save_samples": not txt2imgreq.save_images, + "do_not_save_grid": not txt2imgreq.save_images, + }) if populate.sampler_name: populate.sampler_index = None # prevent a warning later on args = vars(populate) args.pop('script_name', None) - send_images = True if not 'do_not_send' in args else not args['do_not_send'] - args.pop('do_not_send', None) - args.pop('do_not_save', None) + send_images = args.pop('send_images', True) + args.pop('save_images', None) with self.queue_lock: p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args) + p.outpath_grids = opts.outdir_txt2img_grids + p.outpath_samples = opts.outdir_txt2img_samples shared.state.begin() if script is not None: - p.outpath_grids = opts.outdir_txt2img_grids - p.outpath_samples = opts.outdir_txt2img_samples p.script_args = [script_idx + 1] + [None] * (script.args_from - 1) + p.script_args processed = scripts.scripts_txt2img.run(p, *p.script_args) else: @@ -222,13 +220,12 @@ class Api: if mask: mask = decode_base64_to_image(mask) - populate = img2imgreq.copy(update={ # Override __init__ params + populate = img2imgreq.copy(update={ # Override __init__ params "sampler_name": validate_sampler_name(img2imgreq.sampler_name or img2imgreq.sampler_index), - "do_not_save_samples": img2imgreq.do_not_save, - "do_not_save_grid": img2imgreq.do_not_save, - "mask": mask - } - ) + "do_not_save_samples": not img2imgreq.save_images, + "do_not_save_grid": not img2imgreq.save_images, + "mask": mask, + }) if populate.sampler_name: populate.sampler_index = None # prevent a warning later on @@ -236,21 +233,17 @@ class Api: args.pop('include_init_images', None) # this is meant to be done by "exclude": True in model, but it's for a reason that I cannot determine. args.pop('script_name', None) - send_images = True if not 'do_not_send' in args else not args['do_not_send'] - args.pop('do_not_send', None) - args.pop('do_not_save', None) - - send_images = True if not 'do_not_send_images' in args else not args['do_not_send_images'] - args.pop('do_not_send_images', None) + send_images = args.pop('send_images', True) + args.pop('save_images', None) with self.queue_lock: p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args) p.init_images = [decode_base64_to_image(x) for x in init_images] + p.outpath_grids = opts.outdir_img2img_grids + p.outpath_samples = opts.outdir_img2img_samples shared.state.begin() if script is not None: - p.outpath_grids = opts.outdir_img2img_grids - p.outpath_samples = opts.outdir_img2img_samples p.script_args = [script_idx + 1] + [None] * (script.args_from - 1) + p.script_args processed = scripts.scripts_img2img.run(p, *p.script_args) else: diff --git a/modules/api/models.py b/modules/api/models.py index 2b66e1f03..ff3fb3447 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -104,8 +104,8 @@ StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator( {"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "script_name", "type": str, "default": None}, {"key": "script_args", "type": list, "default": []}, - {"key": "do_not_send", "type": bool, "default": False}, - {"key": "do_not_save", "type": bool, "default": True} + {"key": "send_images", "type": bool, "default": True}, + {"key": "save_images", "type": bool, "default": False}, ] ).generate_model() @@ -120,8 +120,8 @@ StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator( {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}, {"key": "script_name", "type": str, "default": None}, {"key": "script_args", "type": list, "default": []}, - {"key": "do_not_send", "type": bool, "default": False}, - {"key": "do_not_save", "type": bool, "default": True} + {"key": "send_images", "type": bool, "default": True}, + {"key": "save_images", "type": bool, "default": False}, ] ).generate_model() diff --git a/modules/images.py b/modules/images.py index f8e62b718..5b80c23e1 100644 --- a/modules/images.py +++ b/modules/images.py @@ -489,9 +489,6 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i """ namegen = FilenameGenerator(p, seed, prompt, image) - if path is None: # set default path to avoid errors when functions are triggered manually or via api and param is not set - path = opts.outdir_save - if save_to_dirs is None: save_to_dirs = (grid and opts.grid_save_to_dirs) or (not grid and opts.save_to_dirs and not no_prompt) From 946797b01de3671c9969f6f7d55e35ef1adaa6e6 Mon Sep 17 00:00:00 2001 From: butaixianran Date: Sat, 11 Mar 2023 18:42:14 +0800 Subject: [PATCH 061/104] update "replace preview" link button's css modify css `.extra-network-thumbs .card:hover .additional a` 's value from `block` to `inline-block`. So, extensions can add more buttons to extra network's thumbnail card. --- style.css | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/style.css b/style.css index 05572f662..f3fb6069d 100644 --- a/style.css +++ b/style.css @@ -856,7 +856,7 @@ footer { } .extra-network-thumbs .card:hover .additional a { - display: block; + display: inline-block; } .extra-network-thumbs .actions .additional a { From aaa367e35ce4e823219c2954ca141ca1ed14800e Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 11 Mar 2023 14:18:18 +0300 Subject: [PATCH 062/104] new setting: Extra text to add before <...> when adding extra network to prompt --- javascript/extraNetworks.js | 2 +- modules/shared.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index 17bf20004..5781df4ff 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -78,7 +78,7 @@ function cardClicked(tabname, textToAdd, allowNegativePrompt){ var textarea = allowNegativePrompt ? activePromptTextarea[tabname] : gradioApp().querySelector("#" + tabname + "_prompt > label > textarea") if(! tryToRemoveExtraNetworkFromPrompt(textarea, textToAdd)){ - textarea.value = textarea.value + " " + textToAdd + textarea.value = textarea.value + opts.extra_networks_add_text_separator + textToAdd } updateInput(textarea) diff --git a/modules/shared.py b/modules/shared.py index dbab00185..28d952dd4 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -442,6 +442,7 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"), options_templates.update(options_section(('extra_networks', "Extra Networks"), { "extra_networks_default_view": OptionInfo("cards", "Default view for Extra Networks", gr.Dropdown, {"choices": ["cards", "thumbs"]}), "extra_networks_default_multiplier": OptionInfo(1.0, "Multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), + "extra_networks_add_text_separator": OptionInfo(" ", "Extra text to add before <...> when adding extra network to prompt"), "sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": [""] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks), })) From 7f2005127ff20170e2d92f353416c7f0705c593b Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 11 Mar 2023 14:52:29 +0300 Subject: [PATCH 063/104] rename CFGDenoiserParams fields for #8064 --- modules/script_callbacks.py | 10 +++++----- modules/sd_samplers_kdiffusion.py | 4 ++-- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index d17031355..079118761 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -29,7 +29,7 @@ class ImageSaveParams: class CFGDenoiserParams: - def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, tensor, uncond): + def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, text_cond, text_uncond): self.x = x """Latent image representation in the process of being denoised""" @@ -45,11 +45,11 @@ class CFGDenoiserParams: self.total_sampling_steps = total_sampling_steps """Total number of sampling steps planned""" - self.tensor = tensor - """ Encoder hidden states of conditioning""" + self.text_cond = text_cond + """ Encoder hidden states of text conditioning from prompt""" - self.uncond = uncond - """ Encoder hidden states of unconditioning""" + self.text_uncond = text_uncond + """ Encoder hidden states of text conditioning from negative prompt""" class CFGDenoisedParams: diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index ea974be04..93f0e55a0 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -106,8 +106,8 @@ class CFGDenoiser(torch.nn.Module): x_in = denoiser_params.x image_cond_in = denoiser_params.image_cond sigma_in = denoiser_params.sigma - tensor = denoiser_params.tensor - uncond = denoiser_params.uncond + tensor = denoiser_params.text_cond + uncond = denoiser_params.text_uncond if tensor.shape[1] == uncond.shape[1]: if not is_edit_model: From d006108d75d74b3237ccbb60a9373bf28c2b85d7 Mon Sep 17 00:00:00 2001 From: Zhang Hua Date: Fri, 10 Mar 2023 16:35:55 +0800 Subject: [PATCH 064/104] webui.sh: remove all `cd` related code This may be helpful for https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/7028, because we won't change working directory to the repo now, instead, we will use any working directory. If we set working directory to a path contains repo and the custom --data-dir, the problem in this issue should be solved. Howewer, this may be treated as an incompatible change if some code assume the working directory is always the repo. Also, there may be another solution that always let --data-dir be the subdirectory of the repo, but personally I think this may not be what we actually need. As this issue mainly influent on Docker and I am not familiar with .bat files, updating webui.bat is skipped. webui.sh: source env from repo instead $PWD --- webui.sh | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/webui.sh b/webui.sh index 8cdad22d3..7b6e0568d 100755 --- a/webui.sh +++ b/webui.sh @@ -6,19 +6,18 @@ # If run from macOS, load defaults from webui-macos-env.sh if [[ "$OSTYPE" == "darwin"* ]]; then - if [[ -f webui-macos-env.sh ]] + if [[ -f "$(dirname $0)/webui-macos-env.sh" ]] then - source ./webui-macos-env.sh + source "$(dirname $0)/webui-macos-env.sh" fi fi # Read variables from webui-user.sh # shellcheck source=/dev/null -if [[ -f webui-user.sh ]] +if [[ -f "$(dirname $0)/webui-user.sh" ]] then - source ./webui-user.sh + source "$(dirname $0)/webui-user.sh" fi - # Set defaults # Install directory without trailing slash if [[ -z "${install_dir}" ]] @@ -47,12 +46,12 @@ fi # python3 venv without trailing slash (defaults to ${install_dir}/${clone_dir}/venv) if [[ -z "${venv_dir}" ]] then - venv_dir="venv" + venv_dir="${install_dir}/${clone_dir}/venv" fi if [[ -z "${LAUNCH_SCRIPT}" ]] then - LAUNCH_SCRIPT="launch.py" + LAUNCH_SCRIPT="${install_dir}/${clone_dir}/launch.py" fi # this script cannot be run as root by default @@ -140,22 +139,23 @@ then exit 1 fi -cd "${install_dir}"/ || { printf "\e[1m\e[31mERROR: Can't cd to %s/, aborting...\e[0m" "${install_dir}"; exit 1; } -if [[ -d "${clone_dir}" ]] +if [[ ! -d "${install_dir}/${clone_dir}" ]] then - cd "${clone_dir}"/ || { printf "\e[1m\e[31mERROR: Can't cd to %s/%s/, aborting...\e[0m" "${install_dir}" "${clone_dir}"; exit 1; } -else printf "\n%s\n" "${delimiter}" printf "Clone stable-diffusion-webui" printf "\n%s\n" "${delimiter}" - "${GIT}" clone https://github.com/AUTOMATIC1111/stable-diffusion-webui.git "${clone_dir}" - cd "${clone_dir}"/ || { printf "\e[1m\e[31mERROR: Can't cd to %s/%s/, aborting...\e[0m" "${install_dir}" "${clone_dir}"; exit 1; } + mkdir -p "${install_dir}" + "${GIT}" clone https://github.com/AUTOMATIC1111/stable-diffusion-webui.git "${install_dir}/${clone_dir}" fi printf "\n%s\n" "${delimiter}" printf "Create and activate python venv" printf "\n%s\n" "${delimiter}" -cd "${install_dir}"/"${clone_dir}"/ || { printf "\e[1m\e[31mERROR: Can't cd to %s/%s/, aborting...\e[0m" "${install_dir}" "${clone_dir}"; exit 1; } +# Make venv_dir absolute +if [[ "${venv_dir}" != /* ]] +then + venv_dir="${install_dir}/${clone_dir}/${venv_dir}" +fi if [[ ! -d "${venv_dir}" ]] then "${python_cmd}" -m venv "${venv_dir}" From 1fa1ab5249ccc7acaafa5f3e11c6925f91543f8b Mon Sep 17 00:00:00 2001 From: Zhang Hua Date: Fri, 10 Mar 2023 21:38:35 +0800 Subject: [PATCH 065/104] launch.py: fix failure because webui.sh's changes launch.py: using getcwd() instead curdir launch.py: use absolute path for preparing also remove chdir() launch.py: use absolute path for test launch.py: add default script_path and data_path --- launch.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/launch.py b/launch.py index 0868f8a90..e1908496a 100644 --- a/launch.py +++ b/launch.py @@ -7,6 +7,11 @@ import shlex import platform import argparse import json +try: + from modules.paths import script_path, data_path +except ModuleNotFoundError: + script_path = os.path.dirname(__file__) + data_path = os.getcwd() dir_repos = "repositories" dir_extensions = "extensions" @@ -122,7 +127,7 @@ def is_installed(package): def repo_dir(name): - return os.path.join(dir_repos, name) + return os.path.join(script_path, dir_repos, name) def run_python(code, desc=None, errdesc=None): @@ -215,7 +220,7 @@ def list_extensions(settings_file): disabled_extensions = set(settings.get('disabled_extensions', [])) - return [x for x in os.listdir(dir_extensions) if x not in disabled_extensions] + return [x for x in os.listdir(os.path.join(data_path, dir_extensions)) if x not in disabled_extensions] def run_extensions_installers(settings_file): @@ -306,7 +311,7 @@ def prepare_environment(): if not is_installed("pyngrok") and ngrok: run_pip("install pyngrok", "ngrok") - os.makedirs(dir_repos, exist_ok=True) + os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True) git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash) git_clone(taming_transformers_repo, repo_dir('taming-transformers'), "Taming Transformers", taming_transformers_commit_hash) @@ -317,7 +322,7 @@ def prepare_environment(): if not is_installed("lpips"): run_pip(f"install -r {os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}", "requirements for CodeFormer") - run_pip(f"install -r {requirements_file}", "requirements for Web UI") + run_pip(f"install -r {os.path.join(script_path, requirements_file)}", "requirements for Web UI") run_extensions_installers(settings_file=args.ui_settings_file) @@ -325,7 +330,7 @@ def prepare_environment(): version_check(commit) if update_all_extensions: - git_pull_recursive(dir_extensions) + git_pull_recursive(os.path.join(data_path, dir_extensions)) if "--exit" in sys.argv: print("Exiting because of --exit argument") @@ -341,7 +346,7 @@ def tests(test_dir): sys.argv.append("--api") if "--ckpt" not in sys.argv: sys.argv.append("--ckpt") - sys.argv.append("./test/test_files/empty.pt") + sys.argv.append(os.path.join(script_path, "test/test_files/empty.pt")) if "--skip-torch-cuda-test" not in sys.argv: sys.argv.append("--skip-torch-cuda-test") if "--disable-nan-check" not in sys.argv: @@ -350,7 +355,7 @@ def tests(test_dir): print(f"Launching Web UI in another process for testing with arguments: {' '.join(sys.argv[1:])}") os.environ['COMMANDLINE_ARGS'] = "" - with open('test/stdout.txt', "w", encoding="utf8") as stdout, open('test/stderr.txt', "w", encoding="utf8") as stderr: + with open(os.path.join(script_path, 'test/stdout.txt'), "w", encoding="utf8") as stdout, open(os.path.join(script_path, 'test/stderr.txt'), "w", encoding="utf8") as stderr: proc = subprocess.Popen([sys.executable, *sys.argv], stdout=stdout, stderr=stderr) import test.server_poll From 8106117a474a5e62dbd73687dadc6e7b7637267d Mon Sep 17 00:00:00 2001 From: Zhang Hua Date: Sat, 11 Mar 2023 09:18:08 +0800 Subject: [PATCH 066/104] models/ui.py: make the path of script.js absolute --- modules/ui.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/modules/ui.py b/modules/ui.py index 0516c6436..56b55f293 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1745,7 +1745,8 @@ def create_ui(): def reload_javascript(): - head = f'\n' + script_js = os.path.join(script_path, "script.js") + head = f'\n' inline = f"{localization.localization_js(shared.opts.localization)};" if cmd_opts.theme is not None: From 8e0d16e746759b1f9b4bf1b5abfc30f3d985415e Mon Sep 17 00:00:00 2001 From: Zhang Hua Date: Sat, 11 Mar 2023 12:22:59 +0800 Subject: [PATCH 067/104] modules/sd_vae_approx.py: fix VAE-approx path --- modules/sd_vae_approx.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/modules/sd_vae_approx.py b/modules/sd_vae_approx.py index 0027343a7..e2f004683 100644 --- a/modules/sd_vae_approx.py +++ b/modules/sd_vae_approx.py @@ -35,8 +35,11 @@ def model(): global sd_vae_approx_model if sd_vae_approx_model is None: + model_path = os.path.join(paths.models_path, "VAE-approx", "model.pt") sd_vae_approx_model = VAEApprox() - sd_vae_approx_model.load_state_dict(torch.load(os.path.join(paths.models_path, "VAE-approx", "model.pt"), map_location='cpu' if devices.device.type != 'cuda' else None)) + if not os.path.exists(model_path): + model_path = os.path.join(paths.script_path, "models", "VAE-approx", "model.pt") + sd_vae_approx_model.load_state_dict(torch.load(model_path, map_location='cpu' if devices.device.type != 'cuda' else None)) sd_vae_approx_model.eval() sd_vae_approx_model.to(devices.device, devices.dtype) From 9abe2f5e74f02d4c8e47d2d4e03464179e9f0aa2 Mon Sep 17 00:00:00 2001 From: Zhang Hua Date: Sat, 11 Mar 2023 13:27:10 +0800 Subject: [PATCH 068/104] test/server_poll.py: use absolute path for test test/server_poll.py: fix absolute path --- test/server_poll.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/test/server_poll.py b/test/server_poll.py index 42d56a4ca..c732630f1 100644 --- a/test/server_poll.py +++ b/test/server_poll.py @@ -1,6 +1,8 @@ import unittest import requests import time +import os +from modules.paths import script_path def run_tests(proc, test_dir): @@ -15,8 +17,8 @@ def run_tests(proc, test_dir): break if proc.poll() is None: if test_dir is None: - test_dir = "test" - suite = unittest.TestLoader().discover(test_dir, pattern="*_test.py", top_level_dir="test") + test_dir = os.path.join(script_path, "test") + suite = unittest.TestLoader().discover(test_dir, pattern="*_test.py", top_level_dir=test_dir) result = unittest.TextTestRunner(verbosity=2).run(suite) return len(result.failures) + len(result.errors) else: From d25c4b13e4be0c89401637c769e3634e7ee456a7 Mon Sep 17 00:00:00 2001 From: Zhang Hua Date: Sat, 11 Mar 2023 14:42:31 +0800 Subject: [PATCH 069/104] test/basic_features/{extras,img2img}_test.py: use absolute path --- test/basic_features/extras_test.py | 8 +++++--- test/basic_features/img2img_test.py | 8 +++++--- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/test/basic_features/extras_test.py b/test/basic_features/extras_test.py index 0170c511f..8ed98747d 100644 --- a/test/basic_features/extras_test.py +++ b/test/basic_features/extras_test.py @@ -1,7 +1,9 @@ +import os import unittest import requests from gradio.processing_utils import encode_pil_to_base64 from PIL import Image +from modules.paths import script_path class TestExtrasWorking(unittest.TestCase): def setUp(self): @@ -19,7 +21,7 @@ class TestExtrasWorking(unittest.TestCase): "upscaler_1": "None", "upscaler_2": "None", "extras_upscaler_2_visibility": 0, - "image": encode_pil_to_base64(Image.open(r"test/test_files/img2img_basic.png")) + "image": encode_pil_to_base64(Image.open(os.path.join(script_path, r"test/test_files/img2img_basic.png"))) } def test_simple_upscaling_performed(self): @@ -31,7 +33,7 @@ class TestPngInfoWorking(unittest.TestCase): def setUp(self): self.url_png_info = "http://localhost:7860/sdapi/v1/extra-single-image" self.png_info = { - "image": encode_pil_to_base64(Image.open(r"test/test_files/img2img_basic.png")) + "image": encode_pil_to_base64(Image.open(os.path.join(script_path, r"test/test_files/img2img_basic.png"))) } def test_png_info_performed(self): @@ -42,7 +44,7 @@ class TestInterrogateWorking(unittest.TestCase): def setUp(self): self.url_interrogate = "http://localhost:7860/sdapi/v1/extra-single-image" self.interrogate = { - "image": encode_pil_to_base64(Image.open(r"test/test_files/img2img_basic.png")), + "image": encode_pil_to_base64(Image.open(os.path.join(script_path, r"test/test_files/img2img_basic.png"))), "model": "clip" } diff --git a/test/basic_features/img2img_test.py b/test/basic_features/img2img_test.py index 08c5c903e..5240ec36a 100644 --- a/test/basic_features/img2img_test.py +++ b/test/basic_features/img2img_test.py @@ -1,14 +1,16 @@ +import os import unittest import requests from gradio.processing_utils import encode_pil_to_base64 from PIL import Image +from modules.paths import script_path class TestImg2ImgWorking(unittest.TestCase): def setUp(self): self.url_img2img = "http://localhost:7860/sdapi/v1/img2img" self.simple_img2img = { - "init_images": [encode_pil_to_base64(Image.open(r"test/test_files/img2img_basic.png"))], + "init_images": [encode_pil_to_base64(Image.open(os.path.join(script_path, r"test/test_files/img2img_basic.png")))], "resize_mode": 0, "denoising_strength": 0.75, "mask": None, @@ -47,11 +49,11 @@ class TestImg2ImgWorking(unittest.TestCase): self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200) def test_inpainting_masked_performed(self): - self.simple_img2img["mask"] = encode_pil_to_base64(Image.open(r"test/test_files/mask_basic.png")) + self.simple_img2img["mask"] = encode_pil_to_base64(Image.open(os.path.join(script_path, r"test/test_files/img2img_basic.png"))) self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200) def test_inpainting_with_inverted_masked_performed(self): - self.simple_img2img["mask"] = encode_pil_to_base64(Image.open(r"test/test_files/mask_basic.png")) + self.simple_img2img["mask"] = encode_pil_to_base64(Image.open(os.path.join(script_path, r"test/test_files/img2img_basic.png"))) self.simple_img2img["inpainting_mask_invert"] = True self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200) From f36ba9949a64fd35a81369e4ec7107b1b87ef7fb Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 11 Mar 2023 15:02:50 +0300 Subject: [PATCH 070/104] add credit for UniPC sampler into the readme --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 2ceb4d2db..24f8e7998 100644 --- a/README.md +++ b/README.md @@ -157,5 +157,6 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al - Sampling in float32 precision from a float16 UNet - marunine for the idea, Birch-san for the example Diffusers implementation (https://github.com/Birch-san/diffusers-play/tree/92feee6) - Instruct pix2pix - Tim Brooks (star), Aleksander Holynski (star), Alexei A. Efros (no star) - https://github.com/timothybrooks/instruct-pix2pix - Security advice - RyotaK +- UniPC sampler - Wenliang Zhao - https://github.com/wl-zhao/UniPC - Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user. - (You) From ce68ab8d0dfdd25e820c58dbc9d3b0148a2022a4 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 11 Mar 2023 15:27:42 +0300 Subject: [PATCH 071/104] remove underscores from function names in #8366 remove LRU from #8366 because I don't know why it's there --- extensions-builtin/Lora/ui_extra_networks_lora.py | 4 ++-- modules/ui_extra_networks.py | 8 +++----- modules/ui_extra_networks_checkpoints.py | 4 ++-- modules/ui_extra_networks_hypernets.py | 4 ++-- modules/ui_extra_networks_textual_inversion.py | 4 ++-- 5 files changed, 11 insertions(+), 13 deletions(-) diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py index 9da13a090..6815f6ef8 100644 --- a/extensions-builtin/Lora/ui_extra_networks_lora.py +++ b/extensions-builtin/Lora/ui_extra_networks_lora.py @@ -18,8 +18,8 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): yield { "name": name, "filename": path, - "preview": self._find_preview(path), - "description": self._find_description(path), + "preview": self.find_preview(path), + "description": self.find_description(path), "search_term": self.search_terms_from_path(lora_on_disk.filename), "prompt": json.dumps(f""), "local_preview": path + ".png", diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index cd61a5694..21c52287b 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -1,9 +1,7 @@ import glob import os.path import urllib.parse -from functools import lru_cache from pathlib import Path -from typing import Optional from modules import shared import gradio as gr @@ -140,17 +138,17 @@ class ExtraNetworksPage: return self.card_page.format(**args) - def _find_preview(self, path: str) -> Optional[str]: + def find_preview(self, path): """ Find a preview PNG for a given path (without extension) and call link_preview on it. """ for file in [path + ".png", path + ".preview.png"]: if os.path.isfile(file): return self.link_preview(file) + return None - @lru_cache(maxsize=512) - def _find_description(self, path: str) -> Optional[str]: + def find_description(self, path): """ Find and read a description file for a given path (without extension). """ diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py index 1deb785aa..7d1aa2030 100644 --- a/modules/ui_extra_networks_checkpoints.py +++ b/modules/ui_extra_networks_checkpoints.py @@ -19,8 +19,8 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage): yield { "name": checkpoint.name_for_extra, "filename": path, - "preview": self._find_preview(path), - "description": self._find_description(path), + "preview": self.find_preview(path), + "description": self.find_description(path), "search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""), "onclick": '"' + html.escape(f"""return selectCheckpoint({json.dumps(name)})""") + '"', "local_preview": path + ".png", diff --git a/modules/ui_extra_networks_hypernets.py b/modules/ui_extra_networks_hypernets.py index 80cc2a248..8e49e93b0 100644 --- a/modules/ui_extra_networks_hypernets.py +++ b/modules/ui_extra_networks_hypernets.py @@ -18,8 +18,8 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage): yield { "name": name, "filename": path, - "preview": self._find_preview(path), - "description": self._find_description(path), + "preview": self.find_preview(path), + "description": self.find_description(path), "search_term": self.search_terms_from_path(path), "prompt": json.dumps(f""), "local_preview": path + ".png", diff --git a/modules/ui_extra_networks_textual_inversion.py b/modules/ui_extra_networks_textual_inversion.py index f3bae6665..1ad806ebf 100644 --- a/modules/ui_extra_networks_textual_inversion.py +++ b/modules/ui_extra_networks_textual_inversion.py @@ -18,8 +18,8 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage): yield { "name": embedding.name, "filename": embedding.filename, - "preview": self._find_preview(path), - "description": self._find_description(path), + "preview": self.find_preview(path), + "description": self.find_description(path), "search_term": self.search_terms_from_path(embedding.filename), "prompt": json.dumps(embedding.name), "local_preview": path + ".preview.png", From 9320139bd8340e8b12c178fe80411c8e25f78d9e Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 11 Mar 2023 15:33:24 +0300 Subject: [PATCH 072/104] support three extensions for preview instead of one: png, jpg, webp --- modules/ui_extra_networks.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 21c52287b..3b476f3a0 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -142,7 +142,11 @@ class ExtraNetworksPage: """ Find a preview PNG for a given path (without extension) and call link_preview on it. """ - for file in [path + ".png", path + ".preview.png"]: + + preview_extensions = ["png", "jpg", "webp"] + potential_files = sum([[path + "." + ext, path + ".preview." + ext] for ext in preview_extensions], []) + + for file in potential_files: if os.path.isfile(file): return self.link_preview(file) From 6da2027213c3bf132c54489d34c48ec084f8dc11 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 11 Mar 2023 15:46:20 +0300 Subject: [PATCH 073/104] save previews for extra networks in the selected format --- extensions-builtin/Lora/ui_extra_networks_lora.py | 2 +- modules/ui_extra_networks.py | 3 +++ modules/ui_extra_networks_checkpoints.py | 2 +- modules/ui_extra_networks_hypernets.py | 2 +- modules/ui_extra_networks_textual_inversion.py | 4 ++-- 5 files changed, 8 insertions(+), 5 deletions(-) diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py index 6815f6ef8..8d32052ec 100644 --- a/extensions-builtin/Lora/ui_extra_networks_lora.py +++ b/extensions-builtin/Lora/ui_extra_networks_lora.py @@ -22,7 +22,7 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): "description": self.find_description(path), "search_term": self.search_terms_from_path(lora_on_disk.filename), "prompt": json.dumps(f""), - "local_preview": path + ".png", + "local_preview": f"{path}.{shared.opts.samples_format}", } def allowed_directories_for_previews(self): diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 3b476f3a0..85f0af4c2 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -144,6 +144,9 @@ class ExtraNetworksPage: """ preview_extensions = ["png", "jpg", "webp"] + if shared.opts.samples_format not in preview_extensions: + preview_extensions.append(shared.opts.samples_format) + potential_files = sum([[path + "." + ext, path + ".preview." + ext] for ext in preview_extensions], []) for file in potential_files: diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py index 7d1aa2030..a17aa9c9c 100644 --- a/modules/ui_extra_networks_checkpoints.py +++ b/modules/ui_extra_networks_checkpoints.py @@ -23,7 +23,7 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage): "description": self.find_description(path), "search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""), "onclick": '"' + html.escape(f"""return selectCheckpoint({json.dumps(name)})""") + '"', - "local_preview": path + ".png", + "local_preview": f"{path}.{shared.opts.samples_format}", } def allowed_directories_for_previews(self): diff --git a/modules/ui_extra_networks_hypernets.py b/modules/ui_extra_networks_hypernets.py index 8e49e93b0..6187e0007 100644 --- a/modules/ui_extra_networks_hypernets.py +++ b/modules/ui_extra_networks_hypernets.py @@ -22,7 +22,7 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage): "description": self.find_description(path), "search_term": self.search_terms_from_path(path), "prompt": json.dumps(f""), - "local_preview": path + ".png", + "local_preview": f"{path}.preview.{shared.opts.samples_format}", } def allowed_directories_for_previews(self): diff --git a/modules/ui_extra_networks_textual_inversion.py b/modules/ui_extra_networks_textual_inversion.py index 1ad806ebf..6944d5593 100644 --- a/modules/ui_extra_networks_textual_inversion.py +++ b/modules/ui_extra_networks_textual_inversion.py @@ -1,7 +1,7 @@ import json import os -from modules import ui_extra_networks, sd_hijack +from modules import ui_extra_networks, sd_hijack, shared class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage): @@ -22,7 +22,7 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage): "description": self.find_description(path), "search_term": self.search_terms_from_path(embedding.filename), "prompt": json.dumps(embedding.name), - "local_preview": path + ".preview.png", + "local_preview": f"{path}.preview.{shared.opts.samples_format}", } def allowed_directories_for_previews(self): From 52dcf0f0c70f1edc4a04ef7bc905528fbc6cdbec Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 11 Mar 2023 16:27:58 +0300 Subject: [PATCH 074/104] record startup time --- webui.py | 47 +++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 43 insertions(+), 4 deletions(-) diff --git a/webui.py b/webui.py index be39fa8dc..325618772 100644 --- a/webui.py +++ b/webui.py @@ -12,11 +12,22 @@ from packaging import version import logging logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage()) -from modules import import_hook, errors, extra_networks, ui_extra_networks_checkpoints -from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion -from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call +from modules import paths, timer, import_hook, errors + +startup_timer = timer.Timer() import torch +startup_timer.record("import torch") + +import gradio +startup_timer.record("import gradio") + +import ldm.modules.encoders.modules +startup_timer.record("import ldm") + +from modules import extra_networks, ui_extra_networks_checkpoints +from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion +from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call # Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors if ".dev" in torch.__version__ or "+git" in torch.__version__: @@ -30,7 +41,6 @@ import modules.gfpgan_model as gfpgan import modules.img2img import modules.lowvram -import modules.paths import modules.scripts import modules.sd_hijack import modules.sd_models @@ -45,6 +55,8 @@ from modules import modelloader from modules.shared import cmd_opts import modules.hypernetworks.hypernetwork +startup_timer.record("other imports") + if cmd_opts.server_name: server_name = cmd_opts.server_name @@ -88,6 +100,7 @@ def initialize(): extensions.list_extensions() localization.list_localizations(cmd_opts.localizations_dir) + startup_timer.record("list extensions") if cmd_opts.ui_debug_mode: shared.sd_upscalers = upscaler.UpscalerLanczos().scalers @@ -96,16 +109,28 @@ def initialize(): modelloader.cleanup_models() modules.sd_models.setup_model() + startup_timer.record("list SD models") + codeformer.setup_model(cmd_opts.codeformer_models_path) + startup_timer.record("setup codeformer") + gfpgan.setup_model(cmd_opts.gfpgan_models_path) + startup_timer.record("setup gfpgan") modelloader.list_builtin_upscalers() + startup_timer.record("list builtin upscalers") + modules.scripts.load_scripts() + startup_timer.record("load scripts") + modelloader.load_upscalers() + startup_timer.record("load upscalers") modules.sd_vae.refresh_vae_list() + startup_timer.record("refresh VAE") modules.textual_inversion.textual_inversion.list_textual_inversion_templates() + startup_timer.record("refresh textual inversion templates") try: modules.sd_models.load_model() @@ -114,6 +139,7 @@ def initialize(): print("", file=sys.stderr) print("Stable diffusion model failed to load, exiting", file=sys.stderr) exit(1) + startup_timer.record("load SD checkpoint") shared.opts.data["sd_model_checkpoint"] = shared.sd_model.sd_checkpoint_info.title @@ -121,8 +147,10 @@ def initialize(): shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed) + startup_timer.record("opts onchange") shared.reload_hypernetworks() + startup_timer.record("reload hypernets") ui_extra_networks.intialize() ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion()) @@ -131,6 +159,7 @@ def initialize(): extra_networks.initialize() extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet()) + startup_timer.record("extra networks") if cmd_opts.tls_keyfile is not None and cmd_opts.tls_keyfile is not None: @@ -144,6 +173,7 @@ def initialize(): print("TLS setup invalid, running webui without TLS") else: print("Running with TLS") + startup_timer.record("TLS") # make the program just exit at ctrl+c without waiting for anything def sigint_handler(sig, frame): @@ -189,6 +219,7 @@ def api_only(): modules.script_callbacks.app_started_callback(None, app) + print(f"Startup time: {startup_timer.summary()}.") api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861) @@ -199,10 +230,13 @@ def webui(): while 1: if shared.opts.clean_temp_dir_at_start: ui_tempdir.cleanup_tmpdr() + startup_timer.record("cleanup temp dir") modules.script_callbacks.before_ui_callback() + startup_timer.record("scripts before_ui_callback") shared.demo = modules.ui.create_ui() + startup_timer.record("create ui") if cmd_opts.gradio_queue: shared.demo.queue(64) @@ -229,6 +263,8 @@ def webui(): # after initial launch, disable --autolaunch for subsequent restarts cmd_opts.autolaunch = False + startup_timer.record("gradio launch") + # gradio uses a very open CORS policy via app.user_middleware, which makes it possible for # an attacker to trick the user into opening a malicious HTML page, which makes a request to the # running web ui and do whatever the attacker wants, including installing an extension and @@ -247,6 +283,9 @@ def webui(): ui_extra_networks.add_pages_to_demo(app) modules.script_callbacks.app_started_callback(shared.demo, app) + startup_timer.record("scripts app_started_callback") + + print(f"Startup time: {startup_timer.summary()}.") wait_on_server(shared.demo) print('Restarting UI...') From a47c18297e1611568c732e6e6922d5be9def7c47 Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Sat, 11 Mar 2023 08:33:55 -0500 Subject: [PATCH 075/104] use assert instead of return --- scripts/xyz_grid.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py index 8c816a736..44f7374cb 100644 --- a/scripts/xyz_grid.py +++ b/scripts/xyz_grid.py @@ -485,10 +485,9 @@ class Script(scripts.Script): zs = process_axis(z_opt, z_values) # this could be moved to common code, but unlikely to be ever triggered anywhere else - Image.MAX_IMAGE_PIXELS = opts.img_max_size_mp * 1000000 * 1.1 # allow 10% overhead for margins and legend + Image.MAX_IMAGE_PIXELS = opts.img_max_size_mp * 1.1 # allow 10% overhead for margins and legend grid_mp = round(len(xs) * len(ys) * len(zs) * p.width * p.height / 1000000) - if grid_mp > opts.img_max_size_mp: - return Processed(p, [], p.seed, info=f'Error: Resulting grid would be too large ({grid_mp} MPixels) (max configured size is {opts.img_max_size_mp} MPixels)') + assert grid_mp < opts.img_max_size_mp, f'Error: Resulting grid would be too large ({grid_mp} MPixels) (max configured size is {opts.img_max_size_mp} MPixels)' def fix_axis_seeds(axis_opt, axis_list): if axis_opt.label in ['Seed', 'Var. seed']: From 1e1a32b130ea8088aeda976cd044544c33a659c3 Mon Sep 17 00:00:00 2001 From: Adam Huganir Date: Sat, 11 Mar 2023 09:34:17 -0500 Subject: [PATCH 076/104] Update requirements_versions.txt revert back to .27 --- requirements_versions.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements_versions.txt b/requirements_versions.txt index 41e0ccc53..331d0fe86 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -23,7 +23,7 @@ torchdiffeq==0.2.3 kornia==0.6.7 lark==1.1.2 inflection==0.5.1 -GitPython==3.1.30 +GitPython==3.1.27 torchsde==0.2.5 safetensors==0.2.7 httpcore<=0.15 From 5cea278d3ad3a1c1a9b454387241a29bec11699b Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 11 Mar 2023 17:51:55 +0300 Subject: [PATCH 077/104] bump GitPython to 3.1.30 because some people would be upset about it being below that version #8118 --- requirements_versions.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements_versions.txt b/requirements_versions.txt index 331d0fe86..41e0ccc53 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -23,7 +23,7 @@ torchdiffeq==0.2.3 kornia==0.6.7 lark==1.1.2 inflection==0.5.1 -GitPython==3.1.27 +GitPython==3.1.30 torchsde==0.2.5 safetensors==0.2.7 httpcore<=0.15 From 7fd19fa4e7039746dc98990883acfa500b90b6c7 Mon Sep 17 00:00:00 2001 From: "Alex \"mcmonkey\" Goodwin" Date: Sat, 11 Mar 2023 07:22:22 -0800 Subject: [PATCH 078/104] initial fix for filename length limits on *nix systems --- modules/images.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/modules/images.py b/modules/images.py index 7df2b08c7..4c204fcaa 100644 --- a/modules/images.py +++ b/modules/images.py @@ -512,6 +512,9 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i file_decoration = "-" + file_decoration file_decoration = namegen.apply(file_decoration) + suffix + if hasattr(os, 'statvfs'): + max_name_len = os.statvfs(path).f_namemax + file_decoration = file_decoration[:max_name_len - 5] if add_number: basecount = get_next_sequence_number(path, basename) From 94ffa9fc5386e51f20692ab46906135e8de33110 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 11 Mar 2023 18:55:48 +0300 Subject: [PATCH 079/104] emergency fix for xyz plot --- scripts/xyz_grid.py | 1 - 1 file changed, 1 deletion(-) diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py index 84c73e28b..9a0678fa7 100644 --- a/scripts/xyz_grid.py +++ b/scripts/xyz_grid.py @@ -500,7 +500,6 @@ class Script(scripts.Script): zs = process_axis(z_opt, z_values) # this could be moved to common code, but unlikely to be ever triggered anywhere else - Image.MAX_IMAGE_PIXELS = opts.img_max_size_mp * 1.1 # allow 10% overhead for margins and legend grid_mp = round(len(xs) * len(ys) * len(zs) * p.width * p.height / 1000000) assert grid_mp < opts.img_max_size_mp, f'Error: Resulting grid would be too large ({grid_mp} MPixels) (max configured size is {opts.img_max_size_mp} MPixels)' From 29ce0bf4f2e708cbd58ee4b9c89d6f27c2f36baa Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Sat, 11 Mar 2023 12:01:08 -0500 Subject: [PATCH 080/104] allow usage of latest fastapi --- requirements_versions.txt | 2 +- webui.py | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/requirements_versions.txt b/requirements_versions.txt index 41e0ccc53..0031c6162 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -27,4 +27,4 @@ GitPython==3.1.30 torchsde==0.2.5 safetensors==0.2.7 httpcore<=0.15 -fastapi==0.90.1 +fastapi==0.94.0 diff --git a/webui.py b/webui.py index 325618772..1a4175afb 100644 --- a/webui.py +++ b/webui.py @@ -183,13 +183,16 @@ def initialize(): signal.signal(signal.SIGINT, sigint_handler) -def setup_cors(app): +def setup_middleware(app): + app.middleware_stack = None # reset current middleware to allow modifying user provided list + app.add_middleware(GZipMiddleware, minimum_size=1000) if cmd_opts.cors_allow_origins and cmd_opts.cors_allow_origins_regex: app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*'], allow_credentials=True, allow_headers=['*']) elif cmd_opts.cors_allow_origins: app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_methods=['*'], allow_credentials=True, allow_headers=['*']) elif cmd_opts.cors_allow_origins_regex: app.add_middleware(CORSMiddleware, allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*'], allow_credentials=True, allow_headers=['*']) + app.build_middleware_stack() # rebuild middleware stack on-the-fly def create_api(app): @@ -213,8 +216,7 @@ def api_only(): initialize() app = FastAPI() - setup_cors(app) - app.add_middleware(GZipMiddleware, minimum_size=1000) + setup_middleware(app) api = create_api(app) modules.script_callbacks.app_started_callback(None, app) @@ -271,9 +273,7 @@ def webui(): # running its code. We disable this here. Suggested by RyotaK. app.user_middleware = [x for x in app.user_middleware if x.cls.__name__ != 'CORSMiddleware'] - setup_cors(app) - - app.add_middleware(GZipMiddleware, minimum_size=1000) + setup_middleware(app) modules.progress.setup_progress_api(app) From 2174f58daee1e077eec1125e196d34cc93dbaf23 Mon Sep 17 00:00:00 2001 From: Vespinian Date: Sat, 11 Mar 2023 12:21:33 -0500 Subject: [PATCH 081/104] Changed alwayson_script_name and alwayson_script_args api params to 1 alwayson_scripts param dict --- modules/api/api.py | 23 +++++++---------------- modules/api/models.py | 4 ++-- 2 files changed, 9 insertions(+), 18 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index 248922d29..8a17017bc 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -195,22 +195,17 @@ class Api: script_args[0] = 0 # Now check for always on scripts - if request.alwayson_script_name and (len(request.alwayson_script_name) > 0): - # always on script with no arg should always run, but if you include their name in the api request, send an empty list for there args - if not request.alwayson_script_args: - raise HTTPException(status_code=422, detail=f"Script {request.alwayson_script_name} has no arg list") - if len(request.alwayson_script_name) != len(request.alwayson_script_args): - raise HTTPException(status_code=422, detail=f"Number of script names and number of script arg lists doesn't match") - - for alwayson_script_name, alwayson_script_args in zip(request.alwayson_script_name, request.alwayson_script_args): + if request.alwayson_scripts and (len(request.alwayson_scripts) > 0): + for alwayson_script_name in request.alwayson_scripts.keys(): alwayson_script = self.get_script(alwayson_script_name, script_runner) if alwayson_script == None: raise HTTPException(status_code=422, detail=f"always on script {alwayson_script_name} not found") # Selectable script in always on script param check if alwayson_script.alwayson == False: raise HTTPException(status_code=422, detail=f"Cannot have a selectable script in the always on scripts params") - if alwayson_script_args != []: - script_args[alwayson_script.args_from:alwayson_script.args_to] = alwayson_script_args + # always on script with no arg should always run so you don't really need to add them to the requests + if "args" in request.alwayson_scripts[alwayson_script_name]: + script_args[alwayson_script.args_from:alwayson_script.args_to] = request.alwayson_scripts[alwayson_script_name]["args"] return script_args def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): @@ -226,15 +221,13 @@ class Api: "do_not_save_grid": True } ) - if populate.sampler_name: populate.sampler_index = None # prevent a warning later on args = vars(populate) args.pop('script_name', None) args.pop('script_args', None) # will refeed them to the pipeline directly after initializing them - args.pop('alwayson_script_name', None) - args.pop('alwayson_script_args', None) + args.pop('alwayson_scripts', None) script_args = self.init_script_args(txt2imgreq, selectable_scripts, selectable_script_idx, script_runner) @@ -279,7 +272,6 @@ class Api: "mask": mask } ) - if populate.sampler_name: populate.sampler_index = None # prevent a warning later on @@ -287,8 +279,7 @@ class Api: args.pop('include_init_images', None) # this is meant to be done by "exclude": True in model, but it's for a reason that I cannot determine. args.pop('script_name', None) args.pop('script_args', None) # will refeed them to the pipeline directly after initializing them - args.pop('alwayson_script_name', None) - args.pop('alwayson_script_args', None) + args.pop('alwayson_scripts', None) script_args = self.init_script_args(img2imgreq, selectable_scripts, selectable_script_idx, script_runner) diff --git a/modules/api/models.py b/modules/api/models.py index 86c701780..e273469d2 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -100,13 +100,13 @@ class PydanticModelGenerator: StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator( "StableDiffusionProcessingTxt2Img", StableDiffusionProcessingTxt2Img, - [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "script_name", "type": str, "default": None}, {"key": "script_args", "type": list, "default": []}, {"key": "alwayson_script_name", "type": list, "default": []}, {"key": "alwayson_script_args", "type": list, "default": []}] + [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "script_name", "type": str, "default": None}, {"key": "script_args", "type": list, "default": []}, {"key": "alwayson_scripts", "type": dict, "default": {}}] ).generate_model() StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator( "StableDiffusionProcessingImg2Img", StableDiffusionProcessingImg2Img, - [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}, {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}, {"key": "script_name", "type": str, "default": None}, {"key": "script_args", "type": list, "default": []}, {"key": "alwayson_script_name", "type": list, "default": []}, {"key": "alwayson_script_args", "type": list, "default": []}] + [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}, {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}, {"key": "script_name", "type": str, "default": None}, {"key": "script_args", "type": list, "default": []}, {"key": "alwayson_scripts", "type": dict, "default": {}}] ).generate_model() class TextToImageResponse(BaseModel): From 5546e71a105033989adacc2df9dfb53f81a0534c Mon Sep 17 00:00:00 2001 From: Vespinian Date: Sat, 11 Mar 2023 12:35:20 -0500 Subject: [PATCH 082/104] Fixed whitespace --- modules/api/models.py | 1 - 1 file changed, 1 deletion(-) diff --git a/modules/api/models.py b/modules/api/models.py index 6c1c55467..4a70f440c 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -113,7 +113,6 @@ StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator( StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator( "StableDiffusionProcessingImg2Img", StableDiffusionProcessingImg2Img, - [ {"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, From 27e319dc4f09a2f040043948e5c52965976f8491 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 11 Mar 2023 21:22:52 +0300 Subject: [PATCH 083/104] alternative solution for #8089 --- modules/shared.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/modules/shared.py b/modules/shared.py index 4e229353b..2fb9e3b5a 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -116,7 +116,10 @@ parser.add_argument("--no-download-sd-model", action='store_true', help="don't d script_loading.preload_extensions(extensions.extensions_dir, parser) script_loading.preload_extensions(extensions.extensions_builtin_dir, parser) -cmd_opts = parser.parse_args() +if os.environ.get('IGNORE_CMD_ARGS_ERRORS', None) is None: + cmd_opts = parser.parse_args() +else: + cmd_opts, _ = parser.parse_known_args() restricted_opts = { "samples_filename_pattern", From 247a34498b337798a371d69483bbcab49b5c320c Mon Sep 17 00:00:00 2001 From: Kilvoctu Date: Sat, 11 Mar 2023 13:11:26 -0600 Subject: [PATCH 084/104] restore text, remove 'close' don't use emojis for extra network buttons; remove 'close' --- javascript/extraNetworks.js | 2 -- modules/ui_extra_networks.py | 8 +------- 2 files changed, 1 insertion(+), 9 deletions(-) diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index 17bf20004..8f7cee035 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -5,12 +5,10 @@ function setupExtraNetworksForTab(tabname){ var tabs = gradioApp().querySelector('#'+tabname+'_extra_tabs > div') var search = gradioApp().querySelector('#'+tabname+'_extra_search textarea') var refresh = gradioApp().getElementById(tabname+'_extra_refresh') - var close = gradioApp().getElementById(tabname+'_extra_close') search.classList.add('search') tabs.appendChild(search) tabs.appendChild(refresh) - tabs.appendChild(close) search.addEventListener("input", function(evt){ searchTerm = search.value.toLowerCase() diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 8786fde6b..3f56bf635 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -13,10 +13,6 @@ from modules.generation_parameters_copypaste import image_from_url_text extra_pages = [] allowed_dirs = set() -# Using constants for these since the variation selector isn't visible. -# Important that they exactly match script.js for tooltip to work. -refresh_symbol = '\U0001f504' # 🔄 -close_symbol = '\U0000274C' # ❌ def register_page(page): """registers extra networks page for the UI; recommend doing it in on_before_ui() callback for extensions""" @@ -186,8 +182,7 @@ def create_ui(container, button, tabname): ui.pages.append(page_elem) filter = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", placeholder="Search...", visible=False) - button_refresh = gr.Button(refresh_symbol, elem_id=tabname+"_extra_refresh") - button_close = gr.Button(close_symbol, elem_id=tabname+"_extra_close") + button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh") ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False) ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False) @@ -198,7 +193,6 @@ def create_ui(container, button, tabname): state_visible = gr.State(value=False) button.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container]) - button_close.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container]) def refresh(): res = [] From 49bbdbe4476a7325e184f3022166c3dc02d086ba Mon Sep 17 00:00:00 2001 From: Vespinian Date: Sat, 11 Mar 2023 14:34:56 -0500 Subject: [PATCH 085/104] small diff whitespace cleanup --- modules/api/api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/api/api.py b/modules/api/api.py index abdbb6a7a..35e17afc9 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -274,7 +274,7 @@ class Api: ui.create_ui() selectable_scripts, selectable_script_idx = self.get_selectable_script(img2imgreq.script_name, script_runner) - populate = img2imgreq.copy(update={ # Override __init__ params + populate = img2imgreq.copy(update={ # Override __init__ params "sampler_name": validate_sampler_name(img2imgreq.sampler_name or img2imgreq.sampler_index), "do_not_save_samples": not img2imgreq.save_images, "do_not_save_grid": not img2imgreq.save_images, From 48f4abd2e61e545104f72eb50c9ab9b100726948 Mon Sep 17 00:00:00 2001 From: EllangoK Date: Sat, 11 Mar 2023 15:52:14 -0500 Subject: [PATCH 086/104] fix dims typo in unipc --- modules/models/diffusion/uni_pc/uni_pc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/models/diffusion/uni_pc/uni_pc.py b/modules/models/diffusion/uni_pc/uni_pc.py index df63d1bcf..e9a093a2b 100644 --- a/modules/models/diffusion/uni_pc/uni_pc.py +++ b/modules/models/diffusion/uni_pc/uni_pc.py @@ -719,7 +719,7 @@ class UniPC: x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * (corr_res + rhos_c[-1] * D1_t) else: x_t_ = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dimss) * x + expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x - expand_dims(sigma_t * h_phi_1, dims) * model_prev_0 ) if x_t is None: From a4cb96d4ae82741be9f0d072a37af3ae39521379 Mon Sep 17 00:00:00 2001 From: brkirch Date: Sat, 11 Mar 2023 17:35:17 -0500 Subject: [PATCH 087/104] Remove test, use bool tensor fix by default The test isn't working correctly on macOS 13.3 and the bool tensor fix for cumsum is currently always needed anyway, so enable the fix by default. --- modules/mac_specific.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/modules/mac_specific.py b/modules/mac_specific.py index ddcea53b9..18e6ff720 100644 --- a/modules/mac_specific.py +++ b/modules/mac_specific.py @@ -23,7 +23,7 @@ def cumsum_fix(input, cumsum_func, *args, **kwargs): output_dtype = kwargs.get('dtype', input.dtype) if output_dtype == torch.int64: return cumsum_func(input.cpu(), *args, **kwargs).to(input.device) - elif cumsum_needs_bool_fix and output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16): + elif output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16): return cumsum_func(input.to(torch.int32), *args, **kwargs).to(torch.int64) return cumsum_func(input, *args, **kwargs) @@ -45,7 +45,6 @@ if has_mps: CondFunc('torch.Tensor.numpy', lambda orig_func, self, *args, **kwargs: orig_func(self.detach(), *args, **kwargs), lambda _, self, *args, **kwargs: self.requires_grad) elif version.parse(torch.__version__) > version.parse("1.13.1"): cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0)) - cumsum_needs_bool_fix = not torch.BoolTensor([True,True]).to(device=torch.device("mps"), dtype=torch.int64).equal(torch.BoolTensor([True,False]).to(torch.device("mps")).cumsum(0)) cumsum_fix_func = lambda orig_func, input, *args, **kwargs: cumsum_fix(input, orig_func, *args, **kwargs) CondFunc('torch.cumsum', cumsum_fix_func, None) CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None) From 5ed5e95fb8a0a4a3292eff22dd1b25e960b066a9 Mon Sep 17 00:00:00 2001 From: high_byte Date: Sun, 12 Mar 2023 03:29:07 +0200 Subject: [PATCH 088/104] add face restoration option to xyz_grid --- scripts/xyz_grid.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py index 9a0678fa7..ce584981b 100644 --- a/scripts/xyz_grid.py +++ b/scripts/xyz_grid.py @@ -132,6 +132,20 @@ def apply_uni_pc_order(p, x, xs): opts.data["uni_pc_order"] = min(x, p.steps - 1) +def apply_face_restore(p, opt, x): + opt = opt.lower() + if opt == 'codeformer': + is_active = True + p.face_restoration_model = 'CodeFormer' + elif opt == 'gfpgan': + is_active = True + p.face_restoration_model = 'GFPGAN' + else: + is_active = opt in ('true', 'yes', 'y', '1') + + p.restore_faces = is_active + + def format_value_add_label(p, opt, x): if type(x) == float: x = round(x, 8) @@ -210,6 +224,7 @@ axis_options = [ AxisOption("VAE", str, apply_vae, cost=0.7, choices=lambda: list(sd_vae.vae_dict)), AxisOption("Styles", str, apply_styles, choices=lambda: list(shared.prompt_styles.styles)), AxisOption("UniPC Order", int, apply_uni_pc_order, cost=0.5), + AxisOption("Face restore", str, apply_face_restore, format_value=format_value), ] From db85421da18cfaaf1ed0361bc2a9ee40b5796344 Mon Sep 17 00:00:00 2001 From: bluelovers Date: Mon, 27 Feb 2023 09:52:55 +0800 Subject: [PATCH 089/104] feat: better lightbox when not enable zoom --- javascript/imageviewer.js | 2 +- style.css | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/javascript/imageviewer.js b/javascript/imageviewer.js index aac2ee823..28e748b74 100644 --- a/javascript/imageviewer.js +++ b/javascript/imageviewer.js @@ -11,7 +11,7 @@ function showModal(event) { if (modalImage.style.display === 'none') { lb.style.setProperty('background-image', 'url(' + source.src + ')'); } - lb.style.display = "block"; + lb.style.display = "flex"; lb.focus() const tabTxt2Img = gradioApp().getElementById("tab_txt2img") diff --git a/style.css b/style.css index 05572f662..4695cc299 100644 --- a/style.css +++ b/style.css @@ -436,9 +436,7 @@ input[type="range"]{ #modalImage { display: block; - margin-left: auto; - margin-right: auto; - margin-top: auto; + margin: auto; width: auto; } From 5c9f2bbb7473c7085dc961bbf81d5248a4859e90 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 12 Mar 2023 08:58:58 +0300 Subject: [PATCH 090/104] do not import modules.paths in launch.py --- launch.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/launch.py b/launch.py index e1908496a..4873da606 100644 --- a/launch.py +++ b/launch.py @@ -7,11 +7,14 @@ import shlex import platform import argparse import json -try: - from modules.paths import script_path, data_path -except ModuleNotFoundError: - script_path = os.path.dirname(__file__) - data_path = os.getcwd() + +parser = argparse.ArgumentParser(add_help=False) +parser.add_argument("--ui-settings-file", type=str, default='config.json') +parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.realpath(__file__))) +args, _ = parser.parse_known_args(sys.argv) + +script_path = os.path.dirname(__file__) +data_path = os.getcwd() dir_repos = "repositories" dir_extensions = "extensions" @@ -257,10 +260,6 @@ def prepare_environment(): sys.argv += shlex.split(commandline_args) - parser = argparse.ArgumentParser(add_help=False) - parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default='config.json') - args, _ = parser.parse_known_args(sys.argv) - sys.argv, _ = extract_arg(sys.argv, '-f') sys.argv, update_all_extensions = extract_arg(sys.argv, '--update-all-extensions') sys.argv, skip_torch_cuda_test = extract_arg(sys.argv, '--skip-torch-cuda-test') From 3c922d983bf60ba187b5422b3690e6b7fb07777e Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 12 Mar 2023 12:11:51 +0300 Subject: [PATCH 091/104] fix #8492 breaking the program when the directory with code contains spaces. --- launch.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/launch.py b/launch.py index 4873da606..b943fed22 100644 --- a/launch.py +++ b/launch.py @@ -319,9 +319,11 @@ def prepare_environment(): git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash) if not is_installed("lpips"): - run_pip(f"install -r {os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}", "requirements for CodeFormer") + run_pip(f"install -r \"{os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}\"", "requirements for CodeFormer") - run_pip(f"install -r {os.path.join(script_path, requirements_file)}", "requirements for Web UI") + if not os.path.isfile(requirements_file): + requirements_file = os.path.join(script_path, requirements_file) + run_pip(f"install -r \"{requirements_file}\"", "requirements for Web UI") run_extensions_installers(settings_file=args.ui_settings_file) From bd67c41f5415c4c42d2383f128fe9ee656153bd0 Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Sun, 12 Mar 2023 09:19:23 -0400 Subject: [PATCH 092/104] force refresh tqdm before close --- modules/shared.py | 1 + 1 file changed, 1 insertion(+) diff --git a/modules/shared.py b/modules/shared.py index 2fb9e3b5a..f28a12ccc 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -714,6 +714,7 @@ class TotalTQDM: def clear(self): if self._tqdm is not None: + self._tqdm.refresh() self._tqdm.close() self._tqdm = None From 27eedb696661d031b9a7b8641b50eaec8dabf64f Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 12 Mar 2023 17:20:04 +0300 Subject: [PATCH 093/104] change extension index link to the new dedicated repo instead of wiki --- modules/ui_extensions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index bd4308ef0..df75a925d 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -304,7 +304,7 @@ def create_ui(): with gr.TabItem("Available"): with gr.Row(): refresh_available_extensions_button = gr.Button(value="Load from:", variant="primary") - available_extensions_index = gr.Text(value="https://raw.githubusercontent.com/wiki/AUTOMATIC1111/stable-diffusion-webui/Extensions-index.md", label="Extension index URL").style(container=False) + available_extensions_index = gr.Text(value="https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui-extensions/master/index.json", label="Extension index URL").style(container=False) extension_to_install = gr.Text(elem_id="extension_to_install", visible=False) install_extension_button = gr.Button(elem_id="install_extension_button", visible=False) From 6033de18bff6c1506879e5f3a645f98131c3f043 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 12 Mar 2023 20:50:02 +0300 Subject: [PATCH 094/104] revert webui.sh from #8492 --- webui.sh | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/webui.sh b/webui.sh index 7b6e0568d..8cdad22d3 100755 --- a/webui.sh +++ b/webui.sh @@ -6,18 +6,19 @@ # If run from macOS, load defaults from webui-macos-env.sh if [[ "$OSTYPE" == "darwin"* ]]; then - if [[ -f "$(dirname $0)/webui-macos-env.sh" ]] + if [[ -f webui-macos-env.sh ]] then - source "$(dirname $0)/webui-macos-env.sh" + source ./webui-macos-env.sh fi fi # Read variables from webui-user.sh # shellcheck source=/dev/null -if [[ -f "$(dirname $0)/webui-user.sh" ]] +if [[ -f webui-user.sh ]] then - source "$(dirname $0)/webui-user.sh" + source ./webui-user.sh fi + # Set defaults # Install directory without trailing slash if [[ -z "${install_dir}" ]] @@ -46,12 +47,12 @@ fi # python3 venv without trailing slash (defaults to ${install_dir}/${clone_dir}/venv) if [[ -z "${venv_dir}" ]] then - venv_dir="${install_dir}/${clone_dir}/venv" + venv_dir="venv" fi if [[ -z "${LAUNCH_SCRIPT}" ]] then - LAUNCH_SCRIPT="${install_dir}/${clone_dir}/launch.py" + LAUNCH_SCRIPT="launch.py" fi # this script cannot be run as root by default @@ -139,23 +140,22 @@ then exit 1 fi -if [[ ! -d "${install_dir}/${clone_dir}" ]] +cd "${install_dir}"/ || { printf "\e[1m\e[31mERROR: Can't cd to %s/, aborting...\e[0m" "${install_dir}"; exit 1; } +if [[ -d "${clone_dir}" ]] then + cd "${clone_dir}"/ || { printf "\e[1m\e[31mERROR: Can't cd to %s/%s/, aborting...\e[0m" "${install_dir}" "${clone_dir}"; exit 1; } +else printf "\n%s\n" "${delimiter}" printf "Clone stable-diffusion-webui" printf "\n%s\n" "${delimiter}" - mkdir -p "${install_dir}" - "${GIT}" clone https://github.com/AUTOMATIC1111/stable-diffusion-webui.git "${install_dir}/${clone_dir}" + "${GIT}" clone https://github.com/AUTOMATIC1111/stable-diffusion-webui.git "${clone_dir}" + cd "${clone_dir}"/ || { printf "\e[1m\e[31mERROR: Can't cd to %s/%s/, aborting...\e[0m" "${install_dir}" "${clone_dir}"; exit 1; } fi printf "\n%s\n" "${delimiter}" printf "Create and activate python venv" printf "\n%s\n" "${delimiter}" -# Make venv_dir absolute -if [[ "${venv_dir}" != /* ]] -then - venv_dir="${install_dir}/${clone_dir}/${venv_dir}" -fi +cd "${install_dir}"/"${clone_dir}"/ || { printf "\e[1m\e[31mERROR: Can't cd to %s/%s/, aborting...\e[0m" "${install_dir}" "${clone_dir}"; exit 1; } if [[ ! -d "${venv_dir}" ]] then "${python_cmd}" -m venv "${venv_dir}" From a00cd8b9c1e9866a58d135f3b64cc7e0f29c6d47 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 12 Mar 2023 21:04:17 +0300 Subject: [PATCH 095/104] attempt to fix memory monitor with multiple CUDA devices --- modules/memmon.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/modules/memmon.py b/modules/memmon.py index a7060f585..4018edcc7 100644 --- a/modules/memmon.py +++ b/modules/memmon.py @@ -23,12 +23,16 @@ class MemUsageMonitor(threading.Thread): self.data = defaultdict(int) try: - torch.cuda.mem_get_info() + self.cuda_mem_get_info() torch.cuda.memory_stats(self.device) except Exception as e: # AMD or whatever print(f"Warning: caught exception '{e}', memory monitor disabled") self.disabled = True + def cuda_mem_get_info(self): + index = self.device.index if self.device.index is not None else torch.cuda.current_device() + return torch.cuda.mem_get_info(index) + def run(self): if self.disabled: return @@ -43,10 +47,10 @@ class MemUsageMonitor(threading.Thread): self.run_flag.clear() continue - self.data["min_free"] = torch.cuda.mem_get_info()[0] + self.data["min_free"] = self.cuda_mem_get_info()[0] while self.run_flag.is_set(): - free, total = torch.cuda.mem_get_info() # calling with self.device errors, torch bug? + free, total = self.cuda_mem_get_info() self.data["min_free"] = min(self.data["min_free"], free) time.sleep(1 / self.opts.memmon_poll_rate) @@ -70,7 +74,7 @@ class MemUsageMonitor(threading.Thread): def read(self): if not self.disabled: - free, total = torch.cuda.mem_get_info() + free, total = self.cuda_mem_get_info() self.data["free"] = free self.data["total"] = total From dfeee786f903e392dbef1519c7c246b9856ebab3 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 12 Mar 2023 21:25:22 +0300 Subject: [PATCH 096/104] display correct timings after restarting UI --- modules/timer.py | 3 +++ webui.py | 12 ++++++++++++ 2 files changed, 15 insertions(+) diff --git a/modules/timer.py b/modules/timer.py index 57a4f17a1..ba92be336 100644 --- a/modules/timer.py +++ b/modules/timer.py @@ -33,3 +33,6 @@ class Timer: res += ")" return res + + def reset(self): + self.__init__() diff --git a/webui.py b/webui.py index 1a4175afb..aaec79fda 100644 --- a/webui.py +++ b/webui.py @@ -290,24 +290,35 @@ def webui(): wait_on_server(shared.demo) print('Restarting UI...') + startup_timer.reset() + sd_samplers.set_samplers() modules.script_callbacks.script_unloaded_callback() extensions.list_extensions() + startup_timer.record("list extensions") localization.list_localizations(cmd_opts.localizations_dir) modelloader.forbid_loaded_nonbuiltin_upscalers() modules.scripts.reload_scripts() + startup_timer.record("load scripts") + modules.script_callbacks.model_loaded_callback(shared.sd_model) + startup_timer.record("model loaded callback") + modelloader.load_upscalers() + startup_timer.record("load upscalers") for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]: importlib.reload(module) + startup_timer.record("reload script modules") modules.sd_models.list_models() + startup_timer.record("list SD models") shared.reload_hypernetworks() + startup_timer.record("reload hypernetworks") ui_extra_networks.intialize() ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion()) @@ -316,6 +327,7 @@ def webui(): extra_networks.initialize() extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet()) + startup_timer.record("initialize extra networks") if __name__ == "__main__": From a71b7b5ec09a24e8a9bb3385e32862da905af6f1 Mon Sep 17 00:00:00 2001 From: "Alex \"mcmonkey\" Goodwin" Date: Sun, 12 Mar 2023 12:30:31 -0700 Subject: [PATCH 097/104] relocate filename length limit to better spot --- modules/images.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/modules/images.py b/modules/images.py index 4c204fcaa..2ce5c67c1 100644 --- a/modules/images.py +++ b/modules/images.py @@ -512,9 +512,6 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i file_decoration = "-" + file_decoration file_decoration = namegen.apply(file_decoration) + suffix - if hasattr(os, 'statvfs'): - max_name_len = os.statvfs(path).f_namemax - file_decoration = file_decoration[:max_name_len - 5] if add_number: basecount = get_next_sequence_number(path, basename) @@ -576,6 +573,10 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i os.replace(temp_file_path, filename_without_extension + extension) fullfn_without_extension, extension = os.path.splitext(params.filename) + if hasattr(os, 'statvfs'): + max_name_len = os.statvfs(path).f_namemax + fullfn_without_extension = fullfn_without_extension[:max_name_len - len(extension)] + params.filename = fullfn_without_extension + extension _atomically_save_image(image, fullfn_without_extension, extension) image.already_saved_as = fullfn From 48df6d66ea627a1c538aba4d37d9141798fff4d7 Mon Sep 17 00:00:00 2001 From: "Alex \"mcmonkey\" Goodwin" Date: Sun, 12 Mar 2023 12:33:29 -0700 Subject: [PATCH 098/104] add safety check in case of short extensions so eg if a two-letter or empty extension is used, `.txt` would break, this `max` call protects that. --- modules/images.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/images.py b/modules/images.py index 2ce5c67c1..18d1de2fc 100644 --- a/modules/images.py +++ b/modules/images.py @@ -575,7 +575,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i fullfn_without_extension, extension = os.path.splitext(params.filename) if hasattr(os, 'statvfs'): max_name_len = os.statvfs(path).f_namemax - fullfn_without_extension = fullfn_without_extension[:max_name_len - len(extension)] + fullfn_without_extension = fullfn_without_extension[:max_name_len - max(4, len(extension))] params.filename = fullfn_without_extension + extension _atomically_save_image(image, fullfn_without_extension, extension) From af9158a8c708c2f710823b298708c51a4d8ba08f Mon Sep 17 00:00:00 2001 From: "Alex \"mcmonkey\" Goodwin" Date: Sun, 12 Mar 2023 12:36:04 -0700 Subject: [PATCH 099/104] update `fullfn` properly --- modules/images.py | 1 + 1 file changed, 1 insertion(+) diff --git a/modules/images.py b/modules/images.py index 18d1de2fc..2da988ee6 100644 --- a/modules/images.py +++ b/modules/images.py @@ -577,6 +577,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i max_name_len = os.statvfs(path).f_namemax fullfn_without_extension = fullfn_without_extension[:max_name_len - max(4, len(extension))] params.filename = fullfn_without_extension + extension + fullfn = params.filename _atomically_save_image(image, fullfn_without_extension, extension) image.already_saved_as = fullfn From 4d26c7da57a621815f25929b35977bd6a3958711 Mon Sep 17 00:00:00 2001 From: high_byte Date: Mon, 13 Mar 2023 17:37:29 +0200 Subject: [PATCH 100/104] initialize extra_network_data before use --- modules/processing.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/modules/processing.py b/modules/processing.py index 06e7a4404..59717b4c6 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -583,6 +583,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if state.job_count == -1: state.job_count = p.n_iter + extra_network_data = None for n in range(p.n_iter): p.iteration = n @@ -712,7 +713,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if opts.grid_save: images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True) - if not p.disable_extra_networks: + if not p.disable_extra_networks and extra_network_data: extra_networks.deactivate(p, extra_network_data) devices.torch_gc() From 03a80f198e8cda66cb6206cb3ae505d66d9eed9d Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Mon, 13 Mar 2023 12:35:30 -0400 Subject: [PATCH 101/104] add pbar to unipc --- modules/models/diffusion/uni_pc/sampler.py | 2 +- modules/models/diffusion/uni_pc/uni_pc.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/modules/models/diffusion/uni_pc/sampler.py b/modules/models/diffusion/uni_pc/sampler.py index bf346ff48..a241c8a7c 100644 --- a/modules/models/diffusion/uni_pc/sampler.py +++ b/modules/models/diffusion/uni_pc/sampler.py @@ -71,7 +71,7 @@ class UniPCSampler(object): # sampling C, H, W = shape size = (batch_size, C, H, W) - print(f'Data shape for UniPC sampling is {size}') + # print(f'Data shape for UniPC sampling is {size}') device = self.model.betas.device if x_T is None: diff --git a/modules/models/diffusion/uni_pc/uni_pc.py b/modules/models/diffusion/uni_pc/uni_pc.py index e9a093a2b..eb5f4e762 100644 --- a/modules/models/diffusion/uni_pc/uni_pc.py +++ b/modules/models/diffusion/uni_pc/uni_pc.py @@ -1,6 +1,7 @@ import torch import torch.nn.functional as F import math +from tqdm.auto import trange class NoiseScheduleVP: @@ -750,7 +751,7 @@ class UniPC: if method == 'multistep': assert steps >= order, "UniPC order must be < sampling steps" timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device) - print(f"Running UniPC Sampling with {timesteps.shape[0]} timesteps, order {order}") + #print(f"Running UniPC Sampling with {timesteps.shape[0]} timesteps, order {order}") assert timesteps.shape[0] - 1 == steps with torch.no_grad(): vec_t = timesteps[0].expand((x.shape[0])) @@ -766,7 +767,7 @@ class UniPC: self.after_update(x, model_x) model_prev_list.append(model_x) t_prev_list.append(vec_t) - for step in range(order, steps + 1): + for step in trange(order, steps + 1): vec_t = timesteps[step].expand(x.shape[0]) if lower_order_final: step_order = min(order, steps + 1 - step) From c19530f1a590d758463f84523dd4c48c34d723e6 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 14 Mar 2023 09:10:26 +0300 Subject: [PATCH 102/104] Add view metadata button for Lora cards. --- extensions-builtin/Lora/lora.py | 21 ++++++- .../Lora/ui_extra_networks_lora.py | 1 + html/extra-networks-card.html | 2 + javascript/extraNetworks.js | 38 +++++++++++- modules/sd_models.py | 24 ++++++++ modules/ui_extra_networks.py | 7 +++ style.css | 61 +++++++++++++++++++ 7 files changed, 152 insertions(+), 2 deletions(-) diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py index cb8f1d36a..8937b585e 100644 --- a/extensions-builtin/Lora/lora.py +++ b/extensions-builtin/Lora/lora.py @@ -3,7 +3,9 @@ import os import re import torch -from modules import shared, devices, sd_models +from modules import shared, devices, sd_models, errors + +metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20} re_digits = re.compile(r"\d+") re_unet_down_blocks = re.compile(r"lora_unet_down_blocks_(\d+)_attentions_(\d+)_(.+)") @@ -43,6 +45,23 @@ class LoraOnDisk: def __init__(self, name, filename): self.name = name self.filename = filename + self.metadata = {} + + _, ext = os.path.splitext(filename) + if ext.lower() == ".safetensors": + try: + self.metadata = sd_models.read_metadata_from_safetensors(filename) + except Exception as e: + errors.display(e, f"reading lora {filename}") + + if self.metadata: + m = {} + for k, v in sorted(self.metadata.items(), key=lambda x: metadata_tags_order.get(x[0], 999)): + m[k] = v + + self.metadata = m + + self.ssmd_cover_images = self.metadata.pop('ssmd_cover_images', None) # those are cover images and they are too big to display in UI as text class LoraModule: diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py index 8d32052ec..68b113323 100644 --- a/extensions-builtin/Lora/ui_extra_networks_lora.py +++ b/extensions-builtin/Lora/ui_extra_networks_lora.py @@ -23,6 +23,7 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): "search_term": self.search_terms_from_path(lora_on_disk.filename), "prompt": json.dumps(f""), "local_preview": f"{path}.{shared.opts.samples_format}", + "metadata": json.dumps(lora_on_disk.metadata, indent=4) if lora_on_disk.metadata else None, } def allowed_directories_for_previews(self): diff --git a/html/extra-networks-card.html b/html/extra-networks-card.html index 8612396d2..1bf3fc30d 100644 --- a/html/extra-networks-card.html +++ b/html/extra-networks-card.html @@ -1,4 +1,6 @@
+ {metadata_button} +
    diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index d0177ad64..2fb87cd5b 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -102,4 +102,40 @@ function extraNetworksSearchButton(tabs_id, event){ searchTextarea.value = text updateInput(searchTextarea) -} \ No newline at end of file +} + +var globalPopup = null; +var globalPopupInner = null; +function popup(contents){ + if(! globalPopup){ + globalPopup = document.createElement('div') + globalPopup.onclick = function(){ globalPopup.style.display = "none"; }; + globalPopup.classList.add('global-popup'); + + var close = document.createElement('div') + close.classList.add('global-popup-close'); + close.onclick = function(){ globalPopup.style.display = "none"; }; + close.title = "Close"; + globalPopup.appendChild(close) + + globalPopupInner = document.createElement('div') + globalPopupInner.onclick = function(event){ event.stopPropagation(); return false; }; + globalPopupInner.classList.add('global-popup-inner'); + globalPopup.appendChild(globalPopupInner) + + gradioApp().appendChild(globalPopup); + } + + globalPopupInner.innerHTML = ''; + globalPopupInner.appendChild(contents); + + globalPopup.style.display = "flex"; +} + +function extraNetworksShowMetadata(text){ + elem = document.createElement('pre') + elem.classList.add('popup-metadata'); + elem.textContent = text; + + popup(elem); +} diff --git a/modules/sd_models.py b/modules/sd_models.py index 93959f55f..5f57ec0c3 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -210,6 +210,30 @@ def get_state_dict_from_checkpoint(pl_sd): return pl_sd +def read_metadata_from_safetensors(filename): + import json + + with open(filename, mode="rb") as file: + metadata_len = file.read(8) + metadata_len = int.from_bytes(metadata_len, "little") + json_start = file.read(2) + + assert metadata_len > 2 and json_start in (b'{"', b"{'"), f"{filename} is not a safetensors file" + json_data = json_start + file.read(metadata_len-2) + json_obj = json.loads(json_data) + + res = {} + for k, v in json_obj.get("__metadata__", {}).items(): + res[k] = v + if isinstance(v, str) and v[0] == '{': + try: + res[k] = json.loads(v) + except Exception as e: + pass + + return res + + def read_state_dict(checkpoint_file, print_global_state=False, map_location=None): _, extension = os.path.splitext(checkpoint_file) if extension.lower() == ".safetensors": diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 01df5e90b..b5847299e 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -124,6 +124,12 @@ class ExtraNetworksPage: if onclick is None: onclick = '"' + html.escape(f"""return cardClicked({json.dumps(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"' + metadata_button = "" + metadata = item.get("metadata") + if metadata: + metadata_onclick = '"' + html.escape(f"""extraNetworksShowMetadata({json.dumps(metadata)}); return false;""") + '"' + metadata_button = f"" + args = { "preview_html": "style='background-image: url(\"" + html.escape(preview) + "\")'" if preview else '', "prompt": item.get("prompt", None), @@ -134,6 +140,7 @@ class ExtraNetworksPage: "card_clicked": onclick, "save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {json.dumps(tabname)}, {json.dumps(item["local_preview"])})""") + '"', "search_term": item.get("search_term", ""), + "metadata_button": metadata_button, } return self.card_page.format(**args) diff --git a/style.css b/style.css index 2f26ad02b..3eac2b176 100644 --- a/style.css +++ b/style.css @@ -362,6 +362,46 @@ input[type="range"]{ height: 100%; } +.popup-metadata{ + color: black; + background: white; + display: inline-block; + padding: 1em; + white-space: pre-wrap; +} + +.global-popup{ + display: flex; + position: fixed; + z-index: 1001; + left: 0; + top: 0; + width: 100%; + height: 100%; + overflow: auto; + background-color: rgba(20, 20, 20, 0.95); +} + + +.global-popup-close:before { + content: "×"; +} + +.global-popup-close{ + position: fixed; + right: 0.25em; + top: 0; + cursor: pointer; + color: white; + font-size: 32pt; +} + +.global-popup-inner{ + display: inline-block; + margin: auto; + padding: 2em; +} + #lightboxModal{ display: none; position: fixed; @@ -837,6 +877,27 @@ footer { margin-left: 0.5em; } + +.extra-network-cards .card .metadata-button:before, .extra-network-thumbs .card .metadata-button:before{ + content: "🛈"; +} +.extra-network-cards .card .metadata-button, .extra-network-thumbs .card .metadata-button{ + display: none; + position: absolute; + right: 0; + color: white; + text-shadow: 2px 2px 3px black; + padding: 0.25em; + font-size: 22pt; +} +.extra-network-cards .card:hover .metadata-button, .extra-network-thumbs .card:hover .metadata-button{ + display: inline-block; +} +.extra-network-cards .card .metadata-button:hover, .extra-network-thumbs .card .metadata-button:hover{ + color: red; +} + + .extra-network-thumbs { display: flex; flex-flow: row wrap; From 4281432594084bc846cc834c807ac57f59457eae Mon Sep 17 00:00:00 2001 From: willtakasan <126040162+willtakasan@users.noreply.github.com> Date: Tue, 14 Mar 2023 15:36:08 +0900 Subject: [PATCH 103/104] Update ui_extra_networks.py I updated it so that no error message is displayed when setting a webp for the preview image. --- modules/ui_extra_networks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index b5847299e..cdfd6f2a0 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -30,8 +30,8 @@ def add_pages_to_demo(app): raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.") ext = os.path.splitext(filename)[1].lower() - if ext not in (".png", ".jpg"): - raise ValueError(f"File cannot be fetched: {filename}. Only png and jpg.") + if ext not in (".png", ".jpg", ".webp"): + raise ValueError(f"File cannot be fetched: {filename}. Only png and jpg and webp.") # would profit from returning 304 return FileResponse(filename, headers={"Accept-Ranges": "bytes"}) From 6a04a7f20fcc4a992ae017b06723e9ceffe17b37 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 14 Mar 2023 11:22:29 +0300 Subject: [PATCH 104/104] fix an error loading Lora with empty values in metadata --- modules/sd_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index 5f57ec0c3..f0cb12400 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -225,7 +225,7 @@ def read_metadata_from_safetensors(filename): res = {} for k, v in json_obj.get("__metadata__", {}).items(): res[k] = v - if isinstance(v, str) and v[0] == '{': + if isinstance(v, str) and v[0:1] == '{': try: res[k] = json.loads(v) except Exception as e: