[Tests] Speed up slow tests (#1040)

* [Tests] Speed up slow tests

* Up

* up
Patrick von Platen 2022-10-28 14:46:39 +02:00 committed by GitHub
parent a80480f0f2
commit d2d9764f35
12 changed files with 58 additions and 36 deletions
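
The change repeated across every hunk below is the same one-liner: the slow/integration tests now pass device_map="auto" to from_pretrained. With accelerate installed, that lets the loader place checkpoint weights directly instead of first building a randomly initialized model and then overwriting it, which is presumably where much of the setup time in these tests went. A minimal sketch of the pattern (not part of the diff; assumes a diffusers version with accelerate available, and reuses one of the checkpoints touched here):

    import torch
    from diffusers import DDIMPipeline

    # device_map="auto" hands weight placement to accelerate, skipping the
    # full random initialization of the UNet before the checkpoint is read.
    pipe = DDIMPipeline.from_pretrained("google/ddpm-cifar10-32", device_map="auto")
    pipe.to("cuda" if torch.cuda.is_available() else "cpu")

    image = pipe(num_inference_steps=5).images[0]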

@@ -86,7 +86,7 @@ class PipelineIntegrationTests(unittest.TestCase):
     def test_dance_diffusion(self):
         device = torch_device
-        pipe = DanceDiffusionPipeline.from_pretrained("harmonai/maestro-150k")
+        pipe = DanceDiffusionPipeline.from_pretrained("harmonai/maestro-150k", device_map="auto")
         pipe = pipe.to(device)
         pipe.set_progress_bar_config(disable=None)
@@ -103,7 +103,9 @@ class PipelineIntegrationTests(unittest.TestCase):
     def test_dance_diffusion_fp16(self):
         device = torch_device
-        pipe = DanceDiffusionPipeline.from_pretrained("harmonai/maestro-150k", torch_dtype=torch.float16)
+        pipe = DanceDiffusionPipeline.from_pretrained(
+            "harmonai/maestro-150k", torch_dtype=torch.float16, device_map="auto"
+        )
         pipe = pipe.to(device)
         pipe.set_progress_bar_config(disable=None)

@@ -78,7 +78,7 @@ class DDIMPipelineIntegrationTests(unittest.TestCase):
     def test_inference_ema_bedroom(self):
         model_id = "google/ddpm-ema-bedroom-256"
-        unet = UNet2DModel.from_pretrained(model_id)
+        unet = UNet2DModel.from_pretrained(model_id, device_map="auto")
         scheduler = DDIMScheduler.from_config(model_id)
         ddpm = DDIMPipeline(unet=unet, scheduler=scheduler)
@@ -97,7 +97,7 @@ class DDIMPipelineIntegrationTests(unittest.TestCase):
     def test_inference_cifar10(self):
         model_id = "google/ddpm-cifar10-32"
-        unet = UNet2DModel.from_pretrained(model_id)
+        unet = UNet2DModel.from_pretrained(model_id, device_map="auto")
         scheduler = DDIMScheduler()
         ddim = DDIMPipeline(unet=unet, scheduler=scheduler)

@@ -38,7 +38,7 @@ class DDPMPipelineIntegrationTests(unittest.TestCase):
     def test_inference_cifar10(self):
         model_id = "google/ddpm-cifar10-32"
-        unet = UNet2DModel.from_pretrained(model_id)
+        unet = UNet2DModel.from_pretrained(model_id, device_map="auto")
         scheduler = DDPMScheduler.from_config(model_id)
         ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)

@@ -70,7 +70,7 @@ class KarrasVePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
 class KarrasVePipelineIntegrationTests(unittest.TestCase):
     def test_inference(self):
         model_id = "google/ncsnpp-celebahq-256"
-        model = UNet2DModel.from_pretrained(model_id)
+        model = UNet2DModel.from_pretrained(model_id, device_map="auto")
         scheduler = KarrasVeScheduler()
         pipe = KarrasVePipeline(unet=model, scheduler=scheduler)

@@ -121,7 +121,7 @@ class LDMTextToImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
 @require_torch
 class LDMTextToImagePipelineIntegrationTests(unittest.TestCase):
     def test_inference_text2img(self):
-        ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")
+        ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256", device_map="auto")
         ldm.to(torch_device)
         ldm.set_progress_bar_config(disable=None)
@@ -138,7 +138,7 @@ class LDMTextToImagePipelineIntegrationTests(unittest.TestCase):
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

     def test_inference_text2img_fast(self):
-        ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")
+        ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256", device_map="auto")
         ldm.to(torch_device)
         ldm.set_progress_bar_config(disable=None)

@@ -71,7 +71,7 @@ class PNDMPipelineIntegrationTests(unittest.TestCase):
     def test_inference_cifar10(self):
         model_id = "google/ddpm-cifar10-32"
-        unet = UNet2DModel.from_pretrained(model_id)
+        unet = UNet2DModel.from_pretrained(model_id, device_map="auto")
         scheduler = PNDMScheduler()
         pndm = PNDMPipeline(unet=unet, scheduler=scheduler)

@@ -72,7 +72,7 @@ class ScoreSdeVeipelineFastTests(PipelineTesterMixin, unittest.TestCase):
 class ScoreSdeVePipelineIntegrationTests(unittest.TestCase):
     def test_inference(self):
         model_id = "google/ncsnpp-church-256"
-        model = UNet2DModel.from_pretrained(model_id)
+        model = UNet2DModel.from_pretrained(model_id, device_map="auto")
         scheduler = ScoreSdeVeScheduler.from_config(model_id)

@@ -528,7 +528,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
     def test_stable_diffusion(self):
         # make sure here that pndm scheduler skips prk
-        sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1")
+        sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1", device_map="auto")
         sd_pipe = sd_pipe.to(torch_device)
         sd_pipe.set_progress_bar_config(disable=None)
@@ -548,7 +548,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

     def test_stable_diffusion_fast_ddim(self):
-        sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1")
+        sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1", device_map="auto")
         sd_pipe = sd_pipe.to(torch_device)
         sd_pipe.set_progress_bar_config(disable=None)
@@ -576,7 +576,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
     def test_lms_stable_diffusion_pipeline(self):
         model_id = "CompVis/stable-diffusion-v1-1"
-        pipe = StableDiffusionPipeline.from_pretrained(model_id).to(torch_device)
+        pipe = StableDiffusionPipeline.from_pretrained(model_id, device_map="auto").to(torch_device)
         pipe.set_progress_bar_config(disable=None)
         scheduler = LMSDiscreteScheduler.from_config(model_id, subfolder="scheduler")
         pipe.scheduler = scheduler
@@ -595,9 +595,10 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
     def test_stable_diffusion_memory_chunking(self):
         torch.cuda.reset_peak_memory_stats()
         model_id = "CompVis/stable-diffusion-v1-4"
-        pipe = StableDiffusionPipeline.from_pretrained(model_id, revision="fp16", torch_dtype=torch.float16).to(
-            torch_device
-        )
+        pipe = StableDiffusionPipeline.from_pretrained(
+            model_id, revision="fp16", torch_dtype=torch.float16, device_map="auto"
+        )
+        pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)

         prompt = "a photograph of an astronaut riding a horse"
@@ -633,9 +634,10 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
     def test_stable_diffusion_text2img_pipeline_fp16(self):
         torch.cuda.reset_peak_memory_stats()
         model_id = "CompVis/stable-diffusion-v1-4"
-        pipe = StableDiffusionPipeline.from_pretrained(model_id, revision="fp16", torch_dtype=torch.float16).to(
-            torch_device
-        )
+        pipe = StableDiffusionPipeline.from_pretrained(
+            model_id, revision="fp16", device_map="auto", torch_dtype=torch.float16
+        )
+        pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)

         prompt = "a photograph of an astronaut riding a horse"
@@ -670,6 +672,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
         pipe = StableDiffusionPipeline.from_pretrained(
             model_id,
             safety_checker=None,
+            device_map="auto",
         )
         pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
@@ -711,7 +714,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
         test_callback_fn.has_been_called = False

         pipe = StableDiffusionPipeline.from_pretrained(
-            "CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16
+            "CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16, device_map="auto"
         )
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
@@ -737,7 +740,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
         start_time = time.time()
         pipeline_normal_load = StableDiffusionPipeline.from_pretrained(
-            pipeline_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True
+            pipeline_id, revision="fp16", torch_dtype=torch.float16, device_map="auto"
         )
         pipeline_normal_load.to(torch_device)
         normal_load_time = time.time() - start_time
@@ -758,7 +761,9 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
         pipeline_id = "CompVis/stable-diffusion-v1-4"
         prompt = "Andromeda galaxy in a bottle"

-        pipeline = StableDiffusionPipeline.from_pretrained(pipeline_id, revision="fp16", torch_dtype=torch.float16)
+        pipeline = StableDiffusionPipeline.from_pretrained(
+            pipeline_id, revision="fp16", torch_dtype=torch.float16, device_map="auto"
+        )
         pipeline.enable_attention_slicing(1)
         pipeline.enable_sequential_cpu_offload()

@@ -488,6 +488,7 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
         pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
             model_id,
             safety_checker=None,
+            device_map="auto",
         )
         pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
@@ -529,6 +530,7 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
             model_id,
             scheduler=lms,
             safety_checker=None,
+            device_map="auto",
         )
         pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
@@ -580,7 +582,7 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
         init_image = init_image.resize((768, 512))

         pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
-            "CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16
+            "CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16, device_map="auto"
         )
         pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)

@@ -288,6 +288,7 @@ class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
         pipe = StableDiffusionInpaintPipeline.from_pretrained(
             model_id,
             safety_checker=None,
+            device_map="auto",
         )
         pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
@@ -329,6 +330,7 @@ class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
             revision="fp16",
             torch_dtype=torch.float16,
             safety_checker=None,
+            device_map="auto",
         )
         pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
@@ -366,7 +368,9 @@ class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
         pndm = PNDMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True)
         model_id = "runwayml/stable-diffusion-inpainting"
-        pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id, safety_checker=None, scheduler=pndm)
+        pipe = StableDiffusionInpaintPipeline.from_pretrained(
+            model_id, safety_checker=None, scheduler=pndm, device_map="auto"
+        )
         pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
         pipe.enable_attention_slicing()

@@ -368,6 +368,7 @@ class StableDiffusionInpaintLegacyPipelineIntegrationTests(unittest.TestCase):
         pipe = StableDiffusionInpaintPipeline.from_pretrained(
             model_id,
             safety_checker=None,
+            device_map="auto",
         )
         pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
@@ -413,6 +414,7 @@ class StableDiffusionInpaintLegacyPipelineIntegrationTests(unittest.TestCase):
             model_id,
             scheduler=lms,
             safety_checker=None,
+            device_map="auto",
         )
         pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
@ -469,7 +471,7 @@ class StableDiffusionInpaintLegacyPipelineIntegrationTests(unittest.TestCase):
) )
pipe = StableDiffusionInpaintPipeline.from_pretrained( pipe = StableDiffusionInpaintPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16 "CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16, device_map="auto"
) )
pipe.to(torch_device) pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)

@@ -108,8 +108,8 @@ class CustomPipelineTests(unittest.TestCase):
     def test_load_pipeline_from_git(self):
         clip_model_id = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"

-        feature_extractor = CLIPFeatureExtractor.from_pretrained(clip_model_id)
-        clip_model = CLIPModel.from_pretrained(clip_model_id, torch_dtype=torch.float16)
+        feature_extractor = CLIPFeatureExtractor.from_pretrained(clip_model_id, device_map="auto")
+        clip_model = CLIPModel.from_pretrained(clip_model_id, torch_dtype=torch.float16, device_map="auto")

         pipeline = DiffusionPipeline.from_pretrained(
             "CompVis/stable-diffusion-v1-4",
@@ -118,6 +118,7 @@ class CustomPipelineTests(unittest.TestCase):
             feature_extractor=feature_extractor,
             torch_dtype=torch.float16,
             revision="fp16",
+            device_map="auto",
         )
         pipeline.enable_attention_slicing()
         pipeline = pipeline.to(torch_device)
@@ -312,7 +313,9 @@ class PipelineSlowTests(unittest.TestCase):
     def test_smart_download(self):
         model_id = "hf-internal-testing/unet-pipeline-dummy"
         with tempfile.TemporaryDirectory() as tmpdirname:
-            _ = DiffusionPipeline.from_pretrained(model_id, cache_dir=tmpdirname, force_download=True)
+            _ = DiffusionPipeline.from_pretrained(
+                model_id, cache_dir=tmpdirname, force_download=True, device_map="auto"
+            )
             local_repo_name = "--".join(["models"] + model_id.split("/"))
             snapshot_dir = os.path.join(tmpdirname, local_repo_name, "snapshots")
             snapshot_dir = os.path.join(snapshot_dir, os.listdir(snapshot_dir)[0])
@@ -335,7 +338,9 @@ class PipelineSlowTests(unittest.TestCase):
         logger = logging.get_logger("diffusers.pipeline_utils")
         with tempfile.TemporaryDirectory() as tmpdirname:
             with CaptureLogger(logger) as cap_logger:
-                DiffusionPipeline.from_pretrained(model_id, not_used=True, cache_dir=tmpdirname, force_download=True)
+                DiffusionPipeline.from_pretrained(
+                    model_id, not_used=True, cache_dir=tmpdirname, force_download=True, device_map="auto"
+                )

         assert cap_logger.out == "Keyword arguments {'not_used': True} not recognized.\n"
@@ -358,7 +363,7 @@ class PipelineSlowTests(unittest.TestCase):
         with tempfile.TemporaryDirectory() as tmpdirname:
             ddpm.save_pretrained(tmpdirname)
-            new_ddpm = DDPMPipeline.from_pretrained(tmpdirname)
+            new_ddpm = DDPMPipeline.from_pretrained(tmpdirname, device_map="auto")
             new_ddpm.to(torch_device)

         generator = torch.manual_seed(0)
@@ -374,10 +379,10 @@ class PipelineSlowTests(unittest.TestCase):
         scheduler = DDPMScheduler(num_train_timesteps=10)

-        ddpm = DDPMPipeline.from_pretrained(model_path, scheduler=scheduler)
+        ddpm = DDPMPipeline.from_pretrained(model_path, scheduler=scheduler, device_map="auto")
         ddpm.to(torch_device)
         ddpm.set_progress_bar_config(disable=None)

-        ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler)
+        ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler, device_map="auto")
         ddpm_from_hub.to(torch_device)
         ddpm_from_hub.set_progress_bar_config(disable=None)
@@ -395,12 +400,14 @@ class PipelineSlowTests(unittest.TestCase):
         scheduler = DDPMScheduler(num_train_timesteps=10)

         # pass unet into DiffusionPipeline
-        unet = UNet2DModel.from_pretrained(model_path)
-        ddpm_from_hub_custom_model = DiffusionPipeline.from_pretrained(model_path, unet=unet, scheduler=scheduler)
+        unet = UNet2DModel.from_pretrained(model_path, device_map="auto")
+        ddpm_from_hub_custom_model = DiffusionPipeline.from_pretrained(
+            model_path, unet=unet, scheduler=scheduler, device_map="auto"
+        )
         ddpm_from_hub_custom_model.to(torch_device)
         ddpm_from_hub_custom_model.set_progress_bar_config(disable=None)

-        ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler)
+        ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler, device_map="auto")
         ddpm_from_hub.to(torch_device)
         ddpm_from_hub_custom_model.set_progress_bar_config(disable=None)
@@ -415,7 +422,7 @@ class PipelineSlowTests(unittest.TestCase):
     def test_output_format(self):
         model_path = "google/ddpm-cifar10-32"

-        pipe = DDIMPipeline.from_pretrained(model_path)
+        pipe = DDIMPipeline.from_pretrained(model_path, device_map="auto")
         pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
@@ -437,7 +444,7 @@ class PipelineSlowTests(unittest.TestCase):
     def test_ddpm_ddim_equality(self):
         model_id = "google/ddpm-cifar10-32"

-        unet = UNet2DModel.from_pretrained(model_id)
+        unet = UNet2DModel.from_pretrained(model_id, device_map="auto")
         ddpm_scheduler = DDPMScheduler()
         ddim_scheduler = DDIMScheduler()
@@ -461,7 +468,7 @@ class PipelineSlowTests(unittest.TestCase):
     def test_ddpm_ddim_equality_batched(self):
         model_id = "google/ddpm-cifar10-32"

-        unet = UNet2DModel.from_pretrained(model_id)
+        unet = UNet2DModel.from_pretrained(model_id, device_map="auto")
         ddpm_scheduler = DDPMScheduler()
         ddim_scheduler = DDIMScheduler()