From 57e6d05a43e4bdf4575e520f1a04c17e80fe58cc Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 9 Jun 2024 21:18:36 +0300 Subject: [PATCH] added tool for profiling code --- modules/call_queue.py | 10 +++++++-- modules/processing.py | 5 +++-- modules/profiling.py | 46 +++++++++++++++++++++++++++++++++++++++ modules/shared_options.py | 16 ++++++++++++++ style.css | 6 ++++- 5 files changed, 78 insertions(+), 5 deletions(-) create mode 100644 modules/profiling.py diff --git a/modules/call_queue.py b/modules/call_queue.py index b50931bcd..d22c23b31 100644 --- a/modules/call_queue.py +++ b/modules/call_queue.py @@ -1,8 +1,9 @@ +import os.path from functools import wraps import html import time -from modules import shared, progress, errors, devices, fifo_lock +from modules import shared, progress, errors, devices, fifo_lock, profiling queue_lock = fifo_lock.FIFOLock() @@ -111,8 +112,13 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False): else: vram_html = '' + if shared.opts.profiling_enable and os.path.exists(shared.opts.profiling_filename): + profiling_html = f"

[ Profile ]

" + else: + profiling_html = '' + # last item is always HTML - res[-1] += f"

Time taken: {elapsed_text}

{vram_html}
" + res[-1] += f"

Time taken: {elapsed_text}

{vram_html}{profiling_html}
" return tuple(res) diff --git a/modules/processing.py b/modules/processing.py index 65e37db0a..91cb94db1 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -16,7 +16,7 @@ from skimage import exposure from typing import Any import modules.sd_hijack -from modules import devices, prompt_parser, masking, sd_samplers, lowvram, infotext_utils, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng +from modules import devices, prompt_parser, masking, sd_samplers, lowvram, infotext_utils, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng, profiling from modules.rng import slerp # noqa: F401 from modules.sd_hijack import model_hijack from modules.sd_samplers_common import images_tensor_to_samples, decode_first_stage, approximation_indexes @@ -843,7 +843,8 @@ def process_images(p: StableDiffusionProcessing) -> Processed: # backwards compatibility, fix sampler and scheduler if invalid sd_samplers.fix_p_invalid_sampler_and_scheduler(p) - res = process_images_inner(p) + with profiling.Profiler(): + res = process_images_inner(p) finally: sd_models.apply_token_merging(p.sd_model, 0) diff --git a/modules/profiling.py b/modules/profiling.py new file mode 100644 index 000000000..95b59f71a --- /dev/null +++ b/modules/profiling.py @@ -0,0 +1,46 @@ +import torch + +from modules import shared, ui_gradio_extensions + + +class Profiler: + def __init__(self): + if not shared.opts.profiling_enable: + self.profiler = None + return + + activities = [] + if "CPU" in shared.opts.profiling_activities: + activities.append(torch.profiler.ProfilerActivity.CPU) + if "CUDA" in shared.opts.profiling_activities: + activities.append(torch.profiler.ProfilerActivity.CUDA) + + if not activities: + self.profiler = None + return + + self.profiler = torch.profiler.profile( + activities=activities, + record_shapes=shared.opts.profiling_record_shapes, + profile_memory=shared.opts.profiling_profile_memory, + with_stack=shared.opts.profiling_with_stack + ) + + def __enter__(self): + if self.profiler: + self.profiler.__enter__() + + return self + + def __exit__(self, exc_type, exc, exc_tb): + if self.profiler: + shared.state.textinfo = "Finishing profile..." + + self.profiler.__exit__(exc_type, exc, exc_tb) + + self.profiler.export_chrome_trace(shared.opts.profiling_filename) + + +def webpath(): + return ui_gradio_extensions.webpath(shared.opts.profiling_filename) + diff --git a/modules/shared_options.py b/modules/shared_options.py index e2e02094f..104d8a544 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -129,6 +129,22 @@ options_templates.update(options_section(('system', "System", "system"), { "dump_stacks_on_signal": OptionInfo(False, "Print stack traces before exiting the program with ctrl+c."), })) +options_templates.update(options_section(('profiler', "Profiler", "system"), { + "profiling_explanation": OptionHTML(""" +Those settings allow you to enable torch profiler when generating pictures. +Profiling allows you to see which code uses how much of computer's resources during generation. +Each generation writes its own profile to one file, overwriting previous. +The file can be viewed in Chrome, or on a Perfetto web site. +Warning: writing profile can take a lot of time, up to 30 seconds, and the file itelf can be around 500MB in size. +"""), + "profiling_enable": OptionInfo(False, "Enable profiling"), + "profiling_activities": OptionInfo(["CPU"], "Activities", gr.CheckboxGroup, {"choices": ["CPU", "CUDA"]}), + "profiling_record_shapes": OptionInfo(True, "Record shapes"), + "profiling_profile_memory": OptionInfo(True, "Profile memory"), + "profiling_with_stack": OptionInfo(True, "Include python stack"), + "profiling_filename": OptionInfo("trace.json", "Profile filename"), +})) + options_templates.update(options_section(('API', "API", "system"), { "api_enable_requests": OptionInfo(True, "Allow http:// and https:// URLs for input images in API", restrict_api=True), "api_forbid_local_requests": OptionInfo(True, "Forbid URLs to local resources", restrict_api=True), diff --git a/style.css b/style.css index 467c29cdf..64ef61bad 100644 --- a/style.css +++ b/style.css @@ -279,7 +279,7 @@ input[type="checkbox"].input-accordion-checkbox{ display: inline-block; } -.html-log .performance p.time, .performance p.vram, .performance p.time abbr, .performance p.vram abbr { +.html-log .performance p.time, .performance p.vram, .performance p.profile, .performance p.time abbr, .performance p.vram abbr { margin-bottom: 0; color: var(--block-title-text-color); } @@ -291,6 +291,10 @@ input[type="checkbox"].input-accordion-checkbox{ margin-left: auto; } +.html-log .performance p.profile { + margin-left: 0.5em; +} + .html-log .performance .measurement{ color: var(--body-text-color); font-weight: bold;