[Tests] Speed up slow tests (#1040)
* [Tests] Speed up slow tests * Up * up
This commit is contained in:
parent
a80480f0f2
commit
d2d9764f35
|
@ -86,7 +86,7 @@ class PipelineIntegrationTests(unittest.TestCase):
|
||||||
def test_dance_diffusion(self):
|
def test_dance_diffusion(self):
|
||||||
device = torch_device
|
device = torch_device
|
||||||
|
|
||||||
pipe = DanceDiffusionPipeline.from_pretrained("harmonai/maestro-150k")
|
pipe = DanceDiffusionPipeline.from_pretrained("harmonai/maestro-150k", device_map="auto")
|
||||||
pipe = pipe.to(device)
|
pipe = pipe.to(device)
|
||||||
pipe.set_progress_bar_config(disable=None)
|
pipe.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
|
@ -103,7 +103,9 @@ class PipelineIntegrationTests(unittest.TestCase):
|
||||||
def test_dance_diffusion_fp16(self):
|
def test_dance_diffusion_fp16(self):
|
||||||
device = torch_device
|
device = torch_device
|
||||||
|
|
||||||
pipe = DanceDiffusionPipeline.from_pretrained("harmonai/maestro-150k", torch_dtype=torch.float16)
|
pipe = DanceDiffusionPipeline.from_pretrained(
|
||||||
|
"harmonai/maestro-150k", torch_dtype=torch.float16, device_map="auto"
|
||||||
|
)
|
||||||
pipe = pipe.to(device)
|
pipe = pipe.to(device)
|
||||||
pipe.set_progress_bar_config(disable=None)
|
pipe.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
|
|
|
@ -78,7 +78,7 @@ class DDIMPipelineIntegrationTests(unittest.TestCase):
|
||||||
def test_inference_ema_bedroom(self):
|
def test_inference_ema_bedroom(self):
|
||||||
model_id = "google/ddpm-ema-bedroom-256"
|
model_id = "google/ddpm-ema-bedroom-256"
|
||||||
|
|
||||||
unet = UNet2DModel.from_pretrained(model_id)
|
unet = UNet2DModel.from_pretrained(model_id, device_map="auto")
|
||||||
scheduler = DDIMScheduler.from_config(model_id)
|
scheduler = DDIMScheduler.from_config(model_id)
|
||||||
|
|
||||||
ddpm = DDIMPipeline(unet=unet, scheduler=scheduler)
|
ddpm = DDIMPipeline(unet=unet, scheduler=scheduler)
|
||||||
|
@ -97,7 +97,7 @@ class DDIMPipelineIntegrationTests(unittest.TestCase):
|
||||||
def test_inference_cifar10(self):
|
def test_inference_cifar10(self):
|
||||||
model_id = "google/ddpm-cifar10-32"
|
model_id = "google/ddpm-cifar10-32"
|
||||||
|
|
||||||
unet = UNet2DModel.from_pretrained(model_id)
|
unet = UNet2DModel.from_pretrained(model_id, device_map="auto")
|
||||||
scheduler = DDIMScheduler()
|
scheduler = DDIMScheduler()
|
||||||
|
|
||||||
ddim = DDIMPipeline(unet=unet, scheduler=scheduler)
|
ddim = DDIMPipeline(unet=unet, scheduler=scheduler)
|
||||||
|
|
|
@ -38,7 +38,7 @@ class DDPMPipelineIntegrationTests(unittest.TestCase):
|
||||||
def test_inference_cifar10(self):
|
def test_inference_cifar10(self):
|
||||||
model_id = "google/ddpm-cifar10-32"
|
model_id = "google/ddpm-cifar10-32"
|
||||||
|
|
||||||
unet = UNet2DModel.from_pretrained(model_id)
|
unet = UNet2DModel.from_pretrained(model_id, device_map="auto")
|
||||||
scheduler = DDPMScheduler.from_config(model_id)
|
scheduler = DDPMScheduler.from_config(model_id)
|
||||||
|
|
||||||
ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)
|
ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)
|
||||||
|
|
|
@ -70,7 +70,7 @@ class KarrasVePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||||
class KarrasVePipelineIntegrationTests(unittest.TestCase):
|
class KarrasVePipelineIntegrationTests(unittest.TestCase):
|
||||||
def test_inference(self):
|
def test_inference(self):
|
||||||
model_id = "google/ncsnpp-celebahq-256"
|
model_id = "google/ncsnpp-celebahq-256"
|
||||||
model = UNet2DModel.from_pretrained(model_id)
|
model = UNet2DModel.from_pretrained(model_id, device_map="auto")
|
||||||
scheduler = KarrasVeScheduler()
|
scheduler = KarrasVeScheduler()
|
||||||
|
|
||||||
pipe = KarrasVePipeline(unet=model, scheduler=scheduler)
|
pipe = KarrasVePipeline(unet=model, scheduler=scheduler)
|
||||||
|
|
|
@ -121,7 +121,7 @@ class LDMTextToImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||||
@require_torch
|
@require_torch
|
||||||
class LDMTextToImagePipelineIntegrationTests(unittest.TestCase):
|
class LDMTextToImagePipelineIntegrationTests(unittest.TestCase):
|
||||||
def test_inference_text2img(self):
|
def test_inference_text2img(self):
|
||||||
ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")
|
ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256", device_map="auto")
|
||||||
ldm.to(torch_device)
|
ldm.to(torch_device)
|
||||||
ldm.set_progress_bar_config(disable=None)
|
ldm.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
|
@ -138,7 +138,7 @@ class LDMTextToImagePipelineIntegrationTests(unittest.TestCase):
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||||
|
|
||||||
def test_inference_text2img_fast(self):
|
def test_inference_text2img_fast(self):
|
||||||
ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")
|
ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256", device_map="auto")
|
||||||
ldm.to(torch_device)
|
ldm.to(torch_device)
|
||||||
ldm.set_progress_bar_config(disable=None)
|
ldm.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
|
|
|
@ -71,7 +71,7 @@ class PNDMPipelineIntegrationTests(unittest.TestCase):
|
||||||
def test_inference_cifar10(self):
|
def test_inference_cifar10(self):
|
||||||
model_id = "google/ddpm-cifar10-32"
|
model_id = "google/ddpm-cifar10-32"
|
||||||
|
|
||||||
unet = UNet2DModel.from_pretrained(model_id)
|
unet = UNet2DModel.from_pretrained(model_id, device_map="auto")
|
||||||
scheduler = PNDMScheduler()
|
scheduler = PNDMScheduler()
|
||||||
|
|
||||||
pndm = PNDMPipeline(unet=unet, scheduler=scheduler)
|
pndm = PNDMPipeline(unet=unet, scheduler=scheduler)
|
||||||
|
|
|
@ -72,7 +72,7 @@ class ScoreSdeVeipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||||
class ScoreSdeVePipelineIntegrationTests(unittest.TestCase):
|
class ScoreSdeVePipelineIntegrationTests(unittest.TestCase):
|
||||||
def test_inference(self):
|
def test_inference(self):
|
||||||
model_id = "google/ncsnpp-church-256"
|
model_id = "google/ncsnpp-church-256"
|
||||||
model = UNet2DModel.from_pretrained(model_id)
|
model = UNet2DModel.from_pretrained(model_id, device_map="auto")
|
||||||
|
|
||||||
scheduler = ScoreSdeVeScheduler.from_config(model_id)
|
scheduler = ScoreSdeVeScheduler.from_config(model_id)
|
||||||
|
|
||||||
|
|
|
@ -528,7 +528,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
|
||||||
|
|
||||||
def test_stable_diffusion(self):
|
def test_stable_diffusion(self):
|
||||||
# make sure here that pndm scheduler skips prk
|
# make sure here that pndm scheduler skips prk
|
||||||
sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1")
|
sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1", device_map="auto")
|
||||||
sd_pipe = sd_pipe.to(torch_device)
|
sd_pipe = sd_pipe.to(torch_device)
|
||||||
sd_pipe.set_progress_bar_config(disable=None)
|
sd_pipe.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
|
@ -548,7 +548,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||||
|
|
||||||
def test_stable_diffusion_fast_ddim(self):
|
def test_stable_diffusion_fast_ddim(self):
|
||||||
sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1")
|
sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1", device_map="auto")
|
||||||
sd_pipe = sd_pipe.to(torch_device)
|
sd_pipe = sd_pipe.to(torch_device)
|
||||||
sd_pipe.set_progress_bar_config(disable=None)
|
sd_pipe.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
|
@ -576,7 +576,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
|
||||||
|
|
||||||
def test_lms_stable_diffusion_pipeline(self):
|
def test_lms_stable_diffusion_pipeline(self):
|
||||||
model_id = "CompVis/stable-diffusion-v1-1"
|
model_id = "CompVis/stable-diffusion-v1-1"
|
||||||
pipe = StableDiffusionPipeline.from_pretrained(model_id).to(torch_device)
|
pipe = StableDiffusionPipeline.from_pretrained(model_id, device_map="auto").to(torch_device)
|
||||||
pipe.set_progress_bar_config(disable=None)
|
pipe.set_progress_bar_config(disable=None)
|
||||||
scheduler = LMSDiscreteScheduler.from_config(model_id, subfolder="scheduler")
|
scheduler = LMSDiscreteScheduler.from_config(model_id, subfolder="scheduler")
|
||||||
pipe.scheduler = scheduler
|
pipe.scheduler = scheduler
|
||||||
|
@ -595,9 +595,10 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
|
||||||
def test_stable_diffusion_memory_chunking(self):
|
def test_stable_diffusion_memory_chunking(self):
|
||||||
torch.cuda.reset_peak_memory_stats()
|
torch.cuda.reset_peak_memory_stats()
|
||||||
model_id = "CompVis/stable-diffusion-v1-4"
|
model_id = "CompVis/stable-diffusion-v1-4"
|
||||||
pipe = StableDiffusionPipeline.from_pretrained(model_id, revision="fp16", torch_dtype=torch.float16).to(
|
pipe = StableDiffusionPipeline.from_pretrained(
|
||||||
torch_device
|
model_id, revision="fp16", torch_dtype=torch.float16, device_map="auto"
|
||||||
)
|
)
|
||||||
|
pipe.to(torch_device)
|
||||||
pipe.set_progress_bar_config(disable=None)
|
pipe.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
prompt = "a photograph of an astronaut riding a horse"
|
prompt = "a photograph of an astronaut riding a horse"
|
||||||
|
@ -633,9 +634,10 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
|
||||||
def test_stable_diffusion_text2img_pipeline_fp16(self):
|
def test_stable_diffusion_text2img_pipeline_fp16(self):
|
||||||
torch.cuda.reset_peak_memory_stats()
|
torch.cuda.reset_peak_memory_stats()
|
||||||
model_id = "CompVis/stable-diffusion-v1-4"
|
model_id = "CompVis/stable-diffusion-v1-4"
|
||||||
pipe = StableDiffusionPipeline.from_pretrained(model_id, revision="fp16", torch_dtype=torch.float16).to(
|
pipe = StableDiffusionPipeline.from_pretrained(
|
||||||
torch_device
|
model_id, revision="fp16", device_map="auto", torch_dtype=torch.float16
|
||||||
)
|
)
|
||||||
|
pipe = pipe.to(torch_device)
|
||||||
pipe.set_progress_bar_config(disable=None)
|
pipe.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
prompt = "a photograph of an astronaut riding a horse"
|
prompt = "a photograph of an astronaut riding a horse"
|
||||||
|
@ -670,6 +672,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
|
||||||
pipe = StableDiffusionPipeline.from_pretrained(
|
pipe = StableDiffusionPipeline.from_pretrained(
|
||||||
model_id,
|
model_id,
|
||||||
safety_checker=None,
|
safety_checker=None,
|
||||||
|
device_map="auto",
|
||||||
)
|
)
|
||||||
pipe.to(torch_device)
|
pipe.to(torch_device)
|
||||||
pipe.set_progress_bar_config(disable=None)
|
pipe.set_progress_bar_config(disable=None)
|
||||||
|
@ -711,7 +714,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
|
||||||
test_callback_fn.has_been_called = False
|
test_callback_fn.has_been_called = False
|
||||||
|
|
||||||
pipe = StableDiffusionPipeline.from_pretrained(
|
pipe = StableDiffusionPipeline.from_pretrained(
|
||||||
"CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16
|
"CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16, device_map="auto"
|
||||||
)
|
)
|
||||||
pipe = pipe.to(torch_device)
|
pipe = pipe.to(torch_device)
|
||||||
pipe.set_progress_bar_config(disable=None)
|
pipe.set_progress_bar_config(disable=None)
|
||||||
|
@ -737,7 +740,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
pipeline_normal_load = StableDiffusionPipeline.from_pretrained(
|
pipeline_normal_load = StableDiffusionPipeline.from_pretrained(
|
||||||
pipeline_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True
|
pipeline_id, revision="fp16", torch_dtype=torch.float16, device_map="auto"
|
||||||
)
|
)
|
||||||
pipeline_normal_load.to(torch_device)
|
pipeline_normal_load.to(torch_device)
|
||||||
normal_load_time = time.time() - start_time
|
normal_load_time = time.time() - start_time
|
||||||
|
@ -758,7 +761,9 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
|
||||||
pipeline_id = "CompVis/stable-diffusion-v1-4"
|
pipeline_id = "CompVis/stable-diffusion-v1-4"
|
||||||
prompt = "Andromeda galaxy in a bottle"
|
prompt = "Andromeda galaxy in a bottle"
|
||||||
|
|
||||||
pipeline = StableDiffusionPipeline.from_pretrained(pipeline_id, revision="fp16", torch_dtype=torch.float16)
|
pipeline = StableDiffusionPipeline.from_pretrained(
|
||||||
|
pipeline_id, revision="fp16", torch_dtype=torch.float16, device_map="auto"
|
||||||
|
)
|
||||||
pipeline.enable_attention_slicing(1)
|
pipeline.enable_attention_slicing(1)
|
||||||
pipeline.enable_sequential_cpu_offload()
|
pipeline.enable_sequential_cpu_offload()
|
||||||
|
|
||||||
|
|
|
@ -488,6 +488,7 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
|
||||||
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
|
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
|
||||||
model_id,
|
model_id,
|
||||||
safety_checker=None,
|
safety_checker=None,
|
||||||
|
device_map="auto",
|
||||||
)
|
)
|
||||||
pipe.to(torch_device)
|
pipe.to(torch_device)
|
||||||
pipe.set_progress_bar_config(disable=None)
|
pipe.set_progress_bar_config(disable=None)
|
||||||
|
@ -529,6 +530,7 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
|
||||||
model_id,
|
model_id,
|
||||||
scheduler=lms,
|
scheduler=lms,
|
||||||
safety_checker=None,
|
safety_checker=None,
|
||||||
|
device_map="auto",
|
||||||
)
|
)
|
||||||
pipe.to(torch_device)
|
pipe.to(torch_device)
|
||||||
pipe.set_progress_bar_config(disable=None)
|
pipe.set_progress_bar_config(disable=None)
|
||||||
|
@ -580,7 +582,7 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
|
||||||
init_image = init_image.resize((768, 512))
|
init_image = init_image.resize((768, 512))
|
||||||
|
|
||||||
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
|
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
|
||||||
"CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16
|
"CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16, device_map="auto"
|
||||||
)
|
)
|
||||||
pipe.to(torch_device)
|
pipe.to(torch_device)
|
||||||
pipe.set_progress_bar_config(disable=None)
|
pipe.set_progress_bar_config(disable=None)
|
||||||
|
|
|
@ -288,6 +288,7 @@ class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
|
||||||
pipe = StableDiffusionInpaintPipeline.from_pretrained(
|
pipe = StableDiffusionInpaintPipeline.from_pretrained(
|
||||||
model_id,
|
model_id,
|
||||||
safety_checker=None,
|
safety_checker=None,
|
||||||
|
device_map="auto",
|
||||||
)
|
)
|
||||||
pipe.to(torch_device)
|
pipe.to(torch_device)
|
||||||
pipe.set_progress_bar_config(disable=None)
|
pipe.set_progress_bar_config(disable=None)
|
||||||
|
@ -329,6 +330,7 @@ class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
|
||||||
revision="fp16",
|
revision="fp16",
|
||||||
torch_dtype=torch.float16,
|
torch_dtype=torch.float16,
|
||||||
safety_checker=None,
|
safety_checker=None,
|
||||||
|
device_map="auto",
|
||||||
)
|
)
|
||||||
pipe.to(torch_device)
|
pipe.to(torch_device)
|
||||||
pipe.set_progress_bar_config(disable=None)
|
pipe.set_progress_bar_config(disable=None)
|
||||||
|
@ -366,7 +368,9 @@ class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
|
||||||
|
|
||||||
pndm = PNDMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True)
|
pndm = PNDMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True)
|
||||||
model_id = "runwayml/stable-diffusion-inpainting"
|
model_id = "runwayml/stable-diffusion-inpainting"
|
||||||
pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id, safety_checker=None, scheduler=pndm)
|
pipe = StableDiffusionInpaintPipeline.from_pretrained(
|
||||||
|
model_id, safety_checker=None, scheduler=pndm, device_map="auto"
|
||||||
|
)
|
||||||
pipe.to(torch_device)
|
pipe.to(torch_device)
|
||||||
pipe.set_progress_bar_config(disable=None)
|
pipe.set_progress_bar_config(disable=None)
|
||||||
pipe.enable_attention_slicing()
|
pipe.enable_attention_slicing()
|
||||||
|
|
|
@ -368,6 +368,7 @@ class StableDiffusionInpaintLegacyPipelineIntegrationTests(unittest.TestCase):
|
||||||
pipe = StableDiffusionInpaintPipeline.from_pretrained(
|
pipe = StableDiffusionInpaintPipeline.from_pretrained(
|
||||||
model_id,
|
model_id,
|
||||||
safety_checker=None,
|
safety_checker=None,
|
||||||
|
device_map="auto",
|
||||||
)
|
)
|
||||||
pipe.to(torch_device)
|
pipe.to(torch_device)
|
||||||
pipe.set_progress_bar_config(disable=None)
|
pipe.set_progress_bar_config(disable=None)
|
||||||
|
@ -413,6 +414,7 @@ class StableDiffusionInpaintLegacyPipelineIntegrationTests(unittest.TestCase):
|
||||||
model_id,
|
model_id,
|
||||||
scheduler=lms,
|
scheduler=lms,
|
||||||
safety_checker=None,
|
safety_checker=None,
|
||||||
|
device_map="auto",
|
||||||
)
|
)
|
||||||
pipe.to(torch_device)
|
pipe.to(torch_device)
|
||||||
pipe.set_progress_bar_config(disable=None)
|
pipe.set_progress_bar_config(disable=None)
|
||||||
|
@ -469,7 +471,7 @@ class StableDiffusionInpaintLegacyPipelineIntegrationTests(unittest.TestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
pipe = StableDiffusionInpaintPipeline.from_pretrained(
|
pipe = StableDiffusionInpaintPipeline.from_pretrained(
|
||||||
"CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16
|
"CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16, device_map="auto"
|
||||||
)
|
)
|
||||||
pipe.to(torch_device)
|
pipe.to(torch_device)
|
||||||
pipe.set_progress_bar_config(disable=None)
|
pipe.set_progress_bar_config(disable=None)
|
||||||
|
|
|
@ -108,8 +108,8 @@ class CustomPipelineTests(unittest.TestCase):
|
||||||
def test_load_pipeline_from_git(self):
|
def test_load_pipeline_from_git(self):
|
||||||
clip_model_id = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
|
clip_model_id = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
|
||||||
|
|
||||||
feature_extractor = CLIPFeatureExtractor.from_pretrained(clip_model_id)
|
feature_extractor = CLIPFeatureExtractor.from_pretrained(clip_model_id, device_map="auto")
|
||||||
clip_model = CLIPModel.from_pretrained(clip_model_id, torch_dtype=torch.float16)
|
clip_model = CLIPModel.from_pretrained(clip_model_id, torch_dtype=torch.float16, device_map="auto")
|
||||||
|
|
||||||
pipeline = DiffusionPipeline.from_pretrained(
|
pipeline = DiffusionPipeline.from_pretrained(
|
||||||
"CompVis/stable-diffusion-v1-4",
|
"CompVis/stable-diffusion-v1-4",
|
||||||
|
@ -118,6 +118,7 @@ class CustomPipelineTests(unittest.TestCase):
|
||||||
feature_extractor=feature_extractor,
|
feature_extractor=feature_extractor,
|
||||||
torch_dtype=torch.float16,
|
torch_dtype=torch.float16,
|
||||||
revision="fp16",
|
revision="fp16",
|
||||||
|
device_map="auto",
|
||||||
)
|
)
|
||||||
pipeline.enable_attention_slicing()
|
pipeline.enable_attention_slicing()
|
||||||
pipeline = pipeline.to(torch_device)
|
pipeline = pipeline.to(torch_device)
|
||||||
|
@ -312,7 +313,9 @@ class PipelineSlowTests(unittest.TestCase):
|
||||||
def test_smart_download(self):
|
def test_smart_download(self):
|
||||||
model_id = "hf-internal-testing/unet-pipeline-dummy"
|
model_id = "hf-internal-testing/unet-pipeline-dummy"
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
_ = DiffusionPipeline.from_pretrained(model_id, cache_dir=tmpdirname, force_download=True)
|
_ = DiffusionPipeline.from_pretrained(
|
||||||
|
model_id, cache_dir=tmpdirname, force_download=True, device_map="auto"
|
||||||
|
)
|
||||||
local_repo_name = "--".join(["models"] + model_id.split("/"))
|
local_repo_name = "--".join(["models"] + model_id.split("/"))
|
||||||
snapshot_dir = os.path.join(tmpdirname, local_repo_name, "snapshots")
|
snapshot_dir = os.path.join(tmpdirname, local_repo_name, "snapshots")
|
||||||
snapshot_dir = os.path.join(snapshot_dir, os.listdir(snapshot_dir)[0])
|
snapshot_dir = os.path.join(snapshot_dir, os.listdir(snapshot_dir)[0])
|
||||||
|
@ -335,7 +338,9 @@ class PipelineSlowTests(unittest.TestCase):
|
||||||
logger = logging.get_logger("diffusers.pipeline_utils")
|
logger = logging.get_logger("diffusers.pipeline_utils")
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
with CaptureLogger(logger) as cap_logger:
|
with CaptureLogger(logger) as cap_logger:
|
||||||
DiffusionPipeline.from_pretrained(model_id, not_used=True, cache_dir=tmpdirname, force_download=True)
|
DiffusionPipeline.from_pretrained(
|
||||||
|
model_id, not_used=True, cache_dir=tmpdirname, force_download=True, device_map="auto"
|
||||||
|
)
|
||||||
|
|
||||||
assert cap_logger.out == "Keyword arguments {'not_used': True} not recognized.\n"
|
assert cap_logger.out == "Keyword arguments {'not_used': True} not recognized.\n"
|
||||||
|
|
||||||
|
@ -358,7 +363,7 @@ class PipelineSlowTests(unittest.TestCase):
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
ddpm.save_pretrained(tmpdirname)
|
ddpm.save_pretrained(tmpdirname)
|
||||||
new_ddpm = DDPMPipeline.from_pretrained(tmpdirname)
|
new_ddpm = DDPMPipeline.from_pretrained(tmpdirname, device_map="auto")
|
||||||
new_ddpm.to(torch_device)
|
new_ddpm.to(torch_device)
|
||||||
|
|
||||||
generator = torch.manual_seed(0)
|
generator = torch.manual_seed(0)
|
||||||
|
@ -374,10 +379,10 @@ class PipelineSlowTests(unittest.TestCase):
|
||||||
|
|
||||||
scheduler = DDPMScheduler(num_train_timesteps=10)
|
scheduler = DDPMScheduler(num_train_timesteps=10)
|
||||||
|
|
||||||
ddpm = DDPMPipeline.from_pretrained(model_path, scheduler=scheduler)
|
ddpm = DDPMPipeline.from_pretrained(model_path, scheduler=scheduler, device_map="auto")
|
||||||
ddpm.to(torch_device)
|
ddpm.to(torch_device)
|
||||||
ddpm.set_progress_bar_config(disable=None)
|
ddpm.set_progress_bar_config(disable=None)
|
||||||
ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler)
|
ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler, device_map="auto")
|
||||||
ddpm_from_hub.to(torch_device)
|
ddpm_from_hub.to(torch_device)
|
||||||
ddpm_from_hub.set_progress_bar_config(disable=None)
|
ddpm_from_hub.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
|
@ -395,12 +400,14 @@ class PipelineSlowTests(unittest.TestCase):
|
||||||
scheduler = DDPMScheduler(num_train_timesteps=10)
|
scheduler = DDPMScheduler(num_train_timesteps=10)
|
||||||
|
|
||||||
# pass unet into DiffusionPipeline
|
# pass unet into DiffusionPipeline
|
||||||
unet = UNet2DModel.from_pretrained(model_path)
|
unet = UNet2DModel.from_pretrained(model_path, device_map="auto")
|
||||||
ddpm_from_hub_custom_model = DiffusionPipeline.from_pretrained(model_path, unet=unet, scheduler=scheduler)
|
ddpm_from_hub_custom_model = DiffusionPipeline.from_pretrained(
|
||||||
|
model_path, unet=unet, scheduler=scheduler, device_map="auto"
|
||||||
|
)
|
||||||
ddpm_from_hub_custom_model.to(torch_device)
|
ddpm_from_hub_custom_model.to(torch_device)
|
||||||
ddpm_from_hub_custom_model.set_progress_bar_config(disable=None)
|
ddpm_from_hub_custom_model.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler)
|
ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler, device_map="auto")
|
||||||
ddpm_from_hub.to(torch_device)
|
ddpm_from_hub.to(torch_device)
|
||||||
ddpm_from_hub_custom_model.set_progress_bar_config(disable=None)
|
ddpm_from_hub_custom_model.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
|
@ -415,7 +422,7 @@ class PipelineSlowTests(unittest.TestCase):
|
||||||
def test_output_format(self):
|
def test_output_format(self):
|
||||||
model_path = "google/ddpm-cifar10-32"
|
model_path = "google/ddpm-cifar10-32"
|
||||||
|
|
||||||
pipe = DDIMPipeline.from_pretrained(model_path)
|
pipe = DDIMPipeline.from_pretrained(model_path, device_map="auto")
|
||||||
pipe.to(torch_device)
|
pipe.to(torch_device)
|
||||||
pipe.set_progress_bar_config(disable=None)
|
pipe.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
|
@ -437,7 +444,7 @@ class PipelineSlowTests(unittest.TestCase):
|
||||||
def test_ddpm_ddim_equality(self):
|
def test_ddpm_ddim_equality(self):
|
||||||
model_id = "google/ddpm-cifar10-32"
|
model_id = "google/ddpm-cifar10-32"
|
||||||
|
|
||||||
unet = UNet2DModel.from_pretrained(model_id)
|
unet = UNet2DModel.from_pretrained(model_id, device_map="auto")
|
||||||
ddpm_scheduler = DDPMScheduler()
|
ddpm_scheduler = DDPMScheduler()
|
||||||
ddim_scheduler = DDIMScheduler()
|
ddim_scheduler = DDIMScheduler()
|
||||||
|
|
||||||
|
@ -461,7 +468,7 @@ class PipelineSlowTests(unittest.TestCase):
|
||||||
def test_ddpm_ddim_equality_batched(self):
|
def test_ddpm_ddim_equality_batched(self):
|
||||||
model_id = "google/ddpm-cifar10-32"
|
model_id = "google/ddpm-cifar10-32"
|
||||||
|
|
||||||
unet = UNet2DModel.from_pretrained(model_id)
|
unet = UNet2DModel.from_pretrained(model_id, device_map="auto")
|
||||||
ddpm_scheduler = DDPMScheduler()
|
ddpm_scheduler = DDPMScheduler()
|
||||||
ddim_scheduler = DDIMScheduler()
|
ddim_scheduler = DDIMScheduler()
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue