stable-diffusion-webui/test/test_torch_utils.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

20 lines
473 B
Python
Raw Normal View History

import types
import pytest
import torch
2023-12-31 12:38:30 -07:00
from modules import torch_utils
@pytest.mark.parametrize("wrapped", [True, False])
def test_get_param(wrapped):
    """Check that ``torch_utils.get_param`` reports a module's dtype and device.

    A 1x1 ``torch.nn.Linear`` is cast to float16 on the CPU. When ``wrapped``
    is true, the layer is first nested under the ``model`` attribute of a plain
    ``types.SimpleNamespace``, so the helper must see through that wrapper too.
    """
    cpu_device = torch.device("cpu")
    linear = torch.nn.Linear(1, 1)
    linear.to(dtype=torch.float16, device=cpu_device)

    # more or less how spandrel wraps a thing
    target = types.SimpleNamespace(model=linear) if wrapped else linear

    param = torch_utils.get_param(target)
    assert param.dtype == torch.float16
    assert param.device == cpu_device