Textual inversion support
This commit is contained in:
parent
34e9795505
commit
4b0188dcbf
25
webui.py
25
webui.py
|
@ -51,6 +51,7 @@ parser.add_argument("--no-half", action='store_true', help="do not switch the mo
|
||||||
parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware accleration in browser)")
|
parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware accleration in browser)")
|
||||||
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
|
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
|
||||||
parser.add_argument("--grid-format", type=str, default='png', help="file format for saved grids; can be png or jpg")
|
parser.add_argument("--grid-format", type=str, default='png', help="file format for saved grids; can be png or jpg")
|
||||||
|
parser.add_argument("--inversion", action='store_true', help="switch to stable inversion version; allows for uploading embeddings; this option should be used only with textual inversion repo")
|
||||||
opt = parser.parse_args()
|
opt = parser.parse_args()
|
||||||
|
|
||||||
GFPGAN_dir = opt.gfpgan_dir
|
GFPGAN_dir = opt.gfpgan_dir
|
||||||
|
@ -151,8 +152,8 @@ if os.path.exists(GFPGAN_dir):
|
||||||
print("Error loading GFPGAN:", file=sys.stderr)
|
print("Error loading GFPGAN:", file=sys.stderr)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
|
|
||||||
config = OmegaConf.load("configs/stable-diffusion/v1-inference.yaml")
|
config = OmegaConf.load(opt.config)
|
||||||
model = load_model_from_config(config, "models/ldm/stable-diffusion-v1/model.ckpt")
|
model = load_model_from_config(config, opt.ckpt)
|
||||||
|
|
||||||
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
||||||
model = (model if opt.no_half else model.half()).to(device)
|
model = (model if opt.no_half else model.half()).to(device)
|
||||||
|
@ -419,9 +420,16 @@ Steps: {steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}{',
|
||||||
return output_images, seed, info
|
return output_images, seed, info
|
||||||
|
|
||||||
|
|
||||||
def txt2img(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, prompt_matrix: bool, ddim_eta: float, n_iter: int, batch_size: int, cfg_scale: float, seed: int, height: int, width: int):
|
def load_embeddings(fp):
|
||||||
|
# load the file
|
||||||
|
model.embedding_manager.load(fp.name)
|
||||||
|
|
||||||
|
|
||||||
|
def txt2img(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, prompt_matrix: bool, ddim_eta: float, n_iter: int, batch_size: int, cfg_scale: float, seed: int, height: int, width: int, embeddings_fp):
|
||||||
outpath = opt.outdir or "outputs/txt2img-samples"
|
outpath = opt.outdir or "outputs/txt2img-samples"
|
||||||
|
|
||||||
|
load_embeddings(embeddings_fp)
|
||||||
|
|
||||||
if sampler_name == 'PLMS':
|
if sampler_name == 'PLMS':
|
||||||
sampler = PLMSSampler(model)
|
sampler = PLMSSampler(model)
|
||||||
elif sampler_name == 'DDIM':
|
elif sampler_name == 'DDIM':
|
||||||
|
@ -516,6 +524,7 @@ txt2img_interface = gr.Interface(
|
||||||
gr.Number(label='Seed', value=-1),
|
gr.Number(label='Seed', value=-1),
|
||||||
gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512),
|
gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512),
|
||||||
gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512),
|
gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512),
|
||||||
|
gr.File(label = "Embeddings file for textual inversion", visible=opt.inversion)
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
gr.Gallery(label="Images"),
|
gr.Gallery(label="Images"),
|
||||||
|
@ -528,9 +537,11 @@ txt2img_interface = gr.Interface(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def img2img(prompt: str, init_img, ddim_steps: int, use_GFPGAN: bool, prompt_matrix, loopback: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int):
|
def img2img(prompt: str, init_img, ddim_steps: int, use_GFPGAN: bool, prompt_matrix, loopback: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int, embeddings_fp):
|
||||||
outpath = opt.outdir or "outputs/img2img-samples"
|
outpath = opt.outdir or "outputs/img2img-samples"
|
||||||
|
|
||||||
|
load_embeddings(embeddings_fp)
|
||||||
|
|
||||||
sampler = KDiffusionSampler(model)
|
sampler = KDiffusionSampler(model)
|
||||||
|
|
||||||
assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
|
assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
|
||||||
|
@ -644,7 +655,8 @@ img2img_interface = gr.Interface(
|
||||||
gr.Number(label='Seed', value=-1),
|
gr.Number(label='Seed', value=-1),
|
||||||
gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512),
|
gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512),
|
||||||
gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512),
|
gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512),
|
||||||
gr.Radio(label="Resize mode", choices=["Just resize", "Crop and resize", "Resize and fill"], type="index", value="Just resize")
|
gr.Radio(label="Resize mode", choices=["Just resize", "Crop and resize", "Resize and fill"], type="index", value="Just resize"),
|
||||||
|
gr.File(label = "Embeddings file for textual inversion", visible=opt.inversion)
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
gr.Gallery(),
|
gr.Gallery(),
|
||||||
|
@ -688,10 +700,11 @@ if GFPGAN is not None:
|
||||||
allow_flagging="never",
|
allow_flagging="never",
|
||||||
), "GFPGAN"))
|
), "GFPGAN"))
|
||||||
|
|
||||||
|
|
||||||
demo = gr.TabbedInterface(
|
demo = gr.TabbedInterface(
|
||||||
interface_list=[x[0] for x in interfaces],
|
interface_list=[x[0] for x in interfaces],
|
||||||
tab_names=[x[1] for x in interfaces],
|
tab_names=[x[1] for x in interfaces],
|
||||||
css=("" if opt.no_progressbar_hiding else css_hide_progressbar)
|
css=("" if opt.no_progressbar_hiding else css_hide_progressbar)
|
||||||
)
|
)
|
||||||
|
|
||||||
demo.launch()
|
demo.launch()
|
Loading…
Reference in New Issue