2023-12-30 15:20:30 -07:00
|
|
|
import types
|
|
|
|
|
|
|
|
import pytest
|
|
|
|
import torch
|
|
|
|
|
2023-12-31 12:38:30 -07:00
|
|
|
from modules import torch_utils
|
2023-12-30 15:20:30 -07:00
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("wrapped", [True, False])
def test_get_param(wrapped):
    """Check that ``torch_utils.get_param`` reports the module's dtype and device.

    A 1x1 ``Linear`` layer is moved to float16 on the CPU.  When ``wrapped`` is
    true, the layer is placed behind a ``SimpleNamespace`` exposing it as
    ``.model`` (more or less how spandrel wraps a model), so the helper is
    exercised both on a bare module and on that wrapper.
    """
    device = torch.device("cpu")
    # nn.Module.to() converts in place and returns the module itself.
    linear = torch.nn.Linear(1, 1).to(dtype=torch.float16, device=device)

    target = types.SimpleNamespace(model=linear) if wrapped else linear

    param = torch_utils.get_param(target)

    assert param.dtype == torch.float16
    assert param.device == device
|