Spaces:
Build error
Build error
File size: 490 Bytes
ddc8a59 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 |
import numpy as np
import torch
from typing import Dict
def run_assertion(
    orig_pt_state_dict: Dict[str, torch.Tensor],
    pt_state_dict_from_tf: Dict[str, torch.Tensor],
) -> None:
    """Check that two state dicts hold numerically matching tensors.

    Every key of ``orig_pt_state_dict`` is looked up in
    ``pt_state_dict_from_tf`` and the two tensors are compared with
    ``numpy.testing.assert_allclose`` using its default tolerances.
    Keys that exist only in ``pt_state_dict_from_tf`` are not inspected.

    Args:
        orig_pt_state_dict: Reference mapping of parameter name to tensor.
        pt_state_dict_from_tf: Mapping to verify against the reference.

    Raises:
        ValueError: If a reference key is absent from
            ``pt_state_dict_from_tf``, or if a pair of tensors cannot be
            converted or does not match. The message names the offending
            key and, where there is one, the underlying error is attached
            as ``__cause__``.
    """
    for k, expected in orig_pt_state_dict.items():
        # Report a missing key explicitly instead of letting the KeyError
        # masquerade as a numeric mismatch.
        if k not in pt_state_dict_from_tf:
            raise ValueError(
                "There are problems in the parameter population process. "
                f"Cannot proceed :( (parameter {k!r} is missing from the "
                "state dict being verified)"
            )

        try:
            # detach()/cpu() make the conversion work for tensors that
            # require grad or live on an accelerator; plain .numpy() would
            # raise for those and be misreported as a value mismatch.
            np.testing.assert_allclose(
                expected.detach().cpu().numpy(),
                pt_state_dict_from_tf[k].detach().cpu().numpy(),
            )
        except Exception as err:
            # Any failure is still surfaced as ValueError, as before, but
            # KeyboardInterrupt/SystemExit now propagate and the root cause
            # is chained for debugging.
            raise ValueError(
                "There are problems in the parameter population process. "
                f"Cannot proceed :( (parameter {k!r}: {err})"
            ) from err