Fix some failing tests (#1041)
* up
* up
* up
* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
* Apply suggestions from code review
This commit is contained in:
parent d2d9764f35
commit 8d6487f3cb
@@ -662,6 +662,8 @@ class LDMBertEncoder(LDMBertPreTrainedModel):
 
 
 class LDMBertModel(LDMBertPreTrainedModel):
+    _no_split_modules = []
+
     def __init__(self, config: LDMBertConfig):
         super().__init__(config)
         self.model = LDMBertEncoder(config)
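When a checkpoint is loaded with `device_map="auto"`, the `_no_split_modules` list added here is (as far as I can tell) forwarded to accelerate as `no_split_module_classes`; an empty list means any submodule may be placed on its own device. A minimal sketch of that mechanism, using a stand-in transformers model rather than `LDMBertModel` itself:

# Minimal sketch, not part of this PR: how a `_no_split_modules` list is used when a
# model is loaded with `device_map="auto"`. `bert-base-uncased` is a stand-in model.
from accelerate import infer_auto_device_map, init_empty_weights
from transformers import AutoConfig, AutoModel

config = AutoConfig.from_pretrained("bert-base-uncased")
with init_empty_weights():
    model = AutoModel.from_config(config)

# An empty list (as added above) tells accelerate it may place any submodule on its
# own device; listing class names here would keep those modules together on one device.
device_map = infer_auto_device_map(model, no_split_module_classes=[])
print(device_map)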
@@ -208,7 +208,6 @@ class StableDiffusionPipeline(DiffusionPipeline):
             list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
             (nsfw) content, according to the `safety_checker`.
         """
-
         if isinstance(prompt, str):
             batch_size = 1
         elif isinstance(prompt, list):
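For readers of the surrounding context: the call derives its batch size from the prompt type. A standalone sketch of that logic, with a hypothetical helper name and a paraphrased error message:

# Sketch of the prompt handling visible in the hunk above; `infer_batch_size` is an
# illustrative helper, not a function in the pipeline.
from typing import List, Union

def infer_batch_size(prompt: Union[str, List[str]]) -> int:
    if isinstance(prompt, str):
        return 1
    elif isinstance(prompt, list):
        return len(prompt)
    raise ValueError(f"`prompt` must be a `str` or a `list`, got {type(prompt)}")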
@@ -740,7 +740,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
 
         start_time = time.time()
         pipeline_normal_load = StableDiffusionPipeline.from_pretrained(
-            pipeline_id, revision="fp16", torch_dtype=torch.float16, device_map="auto"
+            pipeline_id, revision="fp16", torch_dtype=torch.float16
         )
         pipeline_normal_load.to(torch_device)
         normal_load_time = time.time() - start_time
@@ -761,9 +761,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
         pipeline_id = "CompVis/stable-diffusion-v1-4"
         prompt = "Andromeda galaxy in a bottle"
 
-        pipeline = StableDiffusionPipeline.from_pretrained(
-            pipeline_id, revision="fp16", torch_dtype=torch.float16, device_map="auto"
-        )
+        pipeline = StableDiffusionPipeline.from_pretrained(pipeline_id, revision="fp16", torch_dtype=torch.float16)
         pipeline.enable_attention_slicing(1)
         pipeline.enable_sequential_cpu_offload()
 
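These two hunks drop `device_map="auto"` from the fp16 loads; in the second test, memory savings come from attention slicing plus sequential CPU offload instead. A usage sketch of that combination (assumes a CUDA device and `accelerate` installed; the step count is illustrative):

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16
)
pipe.enable_attention_slicing(1)      # compute attention in slices to lower peak memory
pipe.enable_sequential_cpu_offload()  # move submodules to the GPU only while they run

image = pipe("Andromeda galaxy in a bottle", num_inference_steps=20).images[0]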
@@ -77,6 +77,7 @@ class CustomPipelineTests(unittest.TestCase):
         pipeline = DiffusionPipeline.from_pretrained(
             "google/ddpm-cifar10-32", custom_pipeline="hf-internal-testing/diffusers-dummy-pipeline"
         )
+        pipeline = pipeline.to(torch_device)
         # NOTE that `"CustomPipeline"` is not a class that is defined in this library, but solely on the Hub
         # under https://huggingface.co/hf-internal-testing/diffusers-dummy-pipeline/blob/main/pipeline.py#L24
         assert pipeline.__class__.__name__ == "CustomPipeline"
@@ -85,6 +86,7 @@ class CustomPipelineTests(unittest.TestCase):
         pipeline = DiffusionPipeline.from_pretrained(
             "google/ddpm-cifar10-32", custom_pipeline="hf-internal-testing/diffusers-dummy-pipeline"
         )
+        pipeline = pipeline.to(torch_device)
         images, output_str = pipeline(num_inference_steps=2, output_type="np")
 
         assert images[0].shape == (1, 32, 32, 3)
@@ -96,6 +98,7 @@ class CustomPipelineTests(unittest.TestCase):
         pipeline = DiffusionPipeline.from_pretrained(
             "google/ddpm-cifar10-32", custom_pipeline=local_custom_pipeline_path
         )
+        pipeline = pipeline.to(torch_device)
         images, output_str = pipeline(num_inference_steps=2, output_type="np")
 
         assert pipeline.__class__.__name__ == "CustomLocalPipeline"
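The three custom-pipeline tests all gain the same line: the pipeline returned by `.to(torch_device)` is reassigned. A self-contained version of what they exercise:

import torch
from diffusers import DiffusionPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"

# `custom_pipeline` pulls pipeline.py from the referenced Hub repo; the dummy
# pipeline used in these tests returns an (images, string) pair.
pipeline = DiffusionPipeline.from_pretrained(
    "google/ddpm-cifar10-32", custom_pipeline="hf-internal-testing/diffusers-dummy-pipeline"
)
pipeline = pipeline.to(device)

images, output_str = pipeline(num_inference_steps=2, output_type="np")
assert images[0].shape == (1, 32, 32, 3)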
@@ -109,7 +112,7 @@ class CustomPipelineTests(unittest.TestCase):
         clip_model_id = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
 
         feature_extractor = CLIPFeatureExtractor.from_pretrained(clip_model_id, device_map="auto")
-        clip_model = CLIPModel.from_pretrained(clip_model_id, torch_dtype=torch.float16, device_map="auto")
+        clip_model = CLIPModel.from_pretrained(clip_model_id, torch_dtype=torch.float16)
 
         pipeline = DiffusionPipeline.from_pretrained(
             "CompVis/stable-diffusion-v1-4",
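Same idea for the CLIP-guided test: half precision is requested via `torch_dtype` rather than `device_map="auto"` on the model. A minimal sketch of just the CLIP components (the test's `device_map` argument on the feature extractor and the custom pipeline they are later passed into are omitted here):

import torch
from transformers import CLIPFeatureExtractor, CLIPModel

clip_model_id = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"

feature_extractor = CLIPFeatureExtractor.from_pretrained(clip_model_id)
clip_model = CLIPModel.from_pretrained(clip_model_id, torch_dtype=torch.float16)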
@@ -380,10 +383,11 @@ class PipelineSlowTests(unittest.TestCase):
         scheduler = DDPMScheduler(num_train_timesteps=10)
 
         ddpm = DDPMPipeline.from_pretrained(model_path, scheduler=scheduler, device_map="auto")
-        ddpm.to(torch_device)
+        ddpm = ddpm.to(torch_device)
         ddpm.set_progress_bar_config(disable=None)
 
         ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler, device_map="auto")
-        ddpm_from_hub.to(torch_device)
+        ddpm_from_hub = ddpm_from_hub.to(torch_device)
         ddpm_from_hub.set_progress_bar_config(disable=None)
 
         generator = torch.manual_seed(0)
@@ -404,11 +408,11 @@ class PipelineSlowTests(unittest.TestCase):
         ddpm_from_hub_custom_model = DiffusionPipeline.from_pretrained(
             model_path, unet=unet, scheduler=scheduler, device_map="auto"
         )
-        ddpm_from_hub_custom_model.to(torch_device)
+        ddpm_from_hub_custom_model = ddpm_from_hub_custom_model.to(torch_device)
         ddpm_from_hub_custom_model.set_progress_bar_config(disable=None)
 
         ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler, device_map="auto")
-        ddpm_from_hub.to(torch_device)
+        ddpm_from_hub = ddpm_from_hub.to(torch_device)
         ddpm_from_hub_custom_model.set_progress_bar_config(disable=None)
 
         generator = torch.manual_seed(0)
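Both hunks switch from calling `.to(torch_device)` for its side effect to keeping the returned pipeline, matching the custom-pipeline tests above. A minimal sketch of the resulting pattern, with `google/ddpm-cifar10-32` standing in for the test's `model_path`:

import torch
from diffusers import DDPMPipeline, DDPMScheduler

device = "cuda" if torch.cuda.is_available() else "cpu"
scheduler = DDPMScheduler(num_train_timesteps=10)

ddpm = DDPMPipeline.from_pretrained("google/ddpm-cifar10-32", scheduler=scheduler)
ddpm = ddpm.to(device)  # keep the returned pipeline, as the updated tests do
ddpm.set_progress_bar_config(disable=None)

generator = torch.manual_seed(0)
image = ddpm(generator=generator, num_inference_steps=10, output_type="np").images[0]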