|
|
|
|
|
|
|
import os |
|
import sys |
|
|
|
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) |
|
sys.path.append(parent_dir) |
|
|
|
import torch |
|
from alias_free_activation.cuda import activation1d |
|
from activations import Snake |
|
|
|
|
|
def test_load_fused_kernels():
    """Check that the fused anti-alias activation CUDA extension can be loaded.

    Imports ``alias_free_activation.cuda.load`` and calls ``load.load()``.
    Prints "[Success] load_fused_kernels" when that completes. If it raises,
    prints "[Fail] load_fused_kernels" and re-raises the original exception
    unchanged.
    """
    try:
        # Do the real import + load inside the ``try`` so that a missing or
        # broken extension is actually reported as a failure.
        # NOTE(review): the ``__main__`` block also calls ``load.load()``
        # before this check; calling it a second time is assumed to be safe
        # (cached build) -- confirm against load.py.
        from alias_free_activation.cuda import load

        load.load()
    except Exception:
        # Loading can fail with more than ImportError (e.g. a build error), so
        # report any failure. The exception is re-raised, so nothing is
        # swallowed.
        print("[Fail] load_fused_kernels")
        raise
    else:
        print("[Success] load_fused_kernels")
|
|
|
|
|
def test_anti_alias_activation():
    """Compare the fused CUDA anti-alias activation against the PyTorch path.

    Builds two ``activation1d.Activation1d`` modules around the same
    ``Snake(10)`` activation. One uses ``fused=True`` and the other
    ``fused=False``. Both run on the same random CUDA tensor of shape
    ``(10, 10, 200)``. The function then prints a "[Success]" or "[Fail]"
    report with the mean absolute difference and the last row of each output.

    Raises:
        AssertionError: if the mean absolute difference exceeds the 1e-3
            tolerance. This makes test runners record a mismatch as a
            failure instead of only printing it.
    """
    tolerance = 1e-3
    data = torch.rand((10, 10, 200), device="cuda")

    # One Snake instance is shared by both modules, so the fused and
    # non-fused paths are guaranteed to use identical activation parameters.
    snake = Snake(10)

    # Inference-only comparison: no autograd graph is needed.
    with torch.no_grad():
        fused_anti_alias_activation = activation1d.Activation1d(
            activation=snake, fused=True
        ).cuda()
        fused_activation_output = fused_anti_alias_activation(data)

        torch_anti_alias_activation = activation1d.Activation1d(
            activation=snake, fused=False
        ).cuda()
        torch_activation_output = torch_anti_alias_activation(data)

    # Mean absolute difference over every element, as a single global
    # reduction (0-d tensor).
    diff = (fused_activation_output - torch_activation_output).abs().mean()
    passed = bool(diff <= tolerance)

    if passed:
        print(
            f"\n[Success] test_fused_anti_alias_activation"
            f"\n > mean_difference={diff}"
            f"\n > fused_values={fused_activation_output[-1][-1][:].tolist()}"
            f"\n > torch_values={torch_activation_output[-1][-1][:].tolist()}"
        )
    else:
        print(
            f"\n[Fail] test_fused_anti_alias_activation"
            f"\n > mean_difference={diff}, "
            f"\n > fused_values={fused_activation_output[-1][-1][:].tolist()}, "
            f"\n > torch_values={torch_activation_output[-1][-1][:].tolist()}"
        )

    # Fail loudly so the mismatch is not only visible in stdout.
    assert passed, (
        "fused and torch anti-alias activations differ: "
        f"mean_difference={diff.item()} > {tolerance}"
    )
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: load the fused CUDA extension first, then run each
    # check in order.
    from alias_free_activation.cuda import load

    load.load()

    for run_check in (test_load_fused_kernels, test_anti_alias_activation):
        run_check()
|
|