# Spaces: Build error
from typing import Dict

import numpy as np
import torch
def run_assertion(
    orig_pt_state_dict: Dict[str, torch.Tensor],
    pt_state_dict_from_tf: Dict[str, torch.Tensor],
    *,
    rtol: float = 1e-7,
    atol: float = 0.0,
) -> None:
    """Check that a state dict populated from TF weights matches the original.

    Every key of ``orig_pt_state_dict`` is looked up in
    ``pt_state_dict_from_tf`` and the two values are compared element-wise
    with ``np.testing.assert_allclose``. Keys present only in
    ``pt_state_dict_from_tf`` are not checked.

    Args:
        orig_pt_state_dict: Reference state dict.
        pt_state_dict_from_tf: State dict whose values are compared against
            the reference, key by key.
        rtol: Relative tolerance forwarded to ``np.testing.assert_allclose``
            (default matches NumPy's).
        atol: Absolute tolerance forwarded to ``np.testing.assert_allclose``
            (default matches NumPy's).

    Raises:
        ValueError: If a reference key is missing, a value cannot be
            converted to a NumPy array, or the arrays differ in shape or
            value beyond the tolerances. The underlying error is chained as
            ``__cause__`` and the message names the offending key.
    """
    for k in orig_pt_state_dict:
        try:
            # detach()/cpu() allow conversion of tensors that require grad or
            # live on an accelerator; Tensor.numpy() rejects both directly.
            orig = orig_pt_state_dict[k].detach().cpu().numpy()
            populated = pt_state_dict_from_tf[k].detach().cpu().numpy()
            # Argument order is kept from the original: rtol is relative to
            # the second (populated) array.
            np.testing.assert_allclose(orig, populated, rtol=rtol, atol=atol)
        except Exception as err:
            # Any failure for this key means population went wrong; unlike a
            # bare `except:`, this lets KeyboardInterrupt/SystemExit through.
            raise ValueError(
                "There are problems in the parameter population process. "
                f"Cannot proceed :( (offending key: {k!r})"
            ) from err