Spaces:
No application file
No application file
import types | |
import pytest | |
import torch | |
from modules import torch_utils | |
@pytest.mark.parametrize("wrapped", [True, False])
def test_get_param(wrapped):
    """Check that ``torch_utils.get_param`` reports the module's dtype and device.

    The test runs twice: once passing a bare ``torch.nn.Linear`` and once
    passing an object that exposes the module through a ``model`` attribute.

    Args:
        wrapped: When true, the module is hidden behind a
            ``types.SimpleNamespace(model=...)`` before being handed to
            ``get_param``.
    """
    cpu = torch.device("cpu")
    linear = torch.nn.Linear(1, 1)
    # nn.Module.to() converts the module's parameters in place, so the
    # return value does not need to be captured.
    linear.to(dtype=torch.float16, device=cpu)

    target = linear
    if wrapped:
        # more or less how spandrel wraps a thing
        target = types.SimpleNamespace(model=linear)

    p = torch_utils.get_param(target)

    assert p.dtype == torch.float16
    assert p.device == cpu