Spaces:
Runtime error
Runtime error
import collections | |
def load_state(net, checkpoint):
    """Load matching checkpoint entries into ``net``.

    For every key in ``net.state_dict()``, the checkpoint value is used when
    the same key exists in ``checkpoint['state_dict']`` and both values report
    the same ``size()``. Otherwise the network's current value is kept and a
    warning is printed for that key.

    Args:
        net: Object exposing ``state_dict()`` and ``load_state_dict()``
            (presumably a ``torch.nn.Module`` -- confirm against callers).
        checkpoint: Mapping with a ``'state_dict'`` entry that maps parameter
            names to values exposing ``size()``.

    Raises:
        KeyError: If ``checkpoint`` has no ``'state_dict'`` entry.
    """
    source_state = checkpoint['state_dict']
    target_state = net.state_dict()
    new_target_state = collections.OrderedDict()
    for target_key, target_value in target_state.items():
        if (target_key in source_state
                and source_state[target_key].size() == target_value.size()):
            new_target_state[target_key] = source_state[target_key]
        else:
            # Missing key *or* size mismatch: fall back to the value the
            # network already holds so load_state_dict gets a complete dict.
            new_target_state[target_key] = target_value
            print('[WARNING] Not found pre-trained parameters for {}'.format(target_key))

    net.load_state_dict(new_target_state)
def load_from_mobilenet(net, checkpoint):
    """Load checkpoint entries into ``net``, remapping ``model`` key names.

    For every key in ``net.state_dict()``, the lookup key in
    ``checkpoint['state_dict']`` is the same key with each occurrence of
    ``'model'`` replaced by ``'module.model'`` (keys without ``'model'`` are
    looked up unchanged). The checkpoint value is used when the lookup key
    exists and both values report the same ``size()``. Otherwise the network's
    current value is kept and a warning is printed for that key.

    Args:
        net: Object exposing ``state_dict()`` and ``load_state_dict()``
            (presumably a ``torch.nn.Module`` -- confirm against callers).
        checkpoint: Mapping with a ``'state_dict'`` entry that maps parameter
            names to values exposing ``size()``.

    Raises:
        KeyError: If ``checkpoint`` has no ``'state_dict'`` entry.
    """
    source_state = checkpoint['state_dict']
    target_state = net.state_dict()
    new_target_state = collections.OrderedDict()
    for target_key, target_value in target_state.items():
        source_key = target_key
        if 'model' in source_key:
            # NOTE(review): the 'module.' prefix presumably comes from a
            # DataParallel-wrapped checkpoint -- confirm against the
            # checkpoint's origin.
            source_key = source_key.replace('model', 'module.model')

        if (source_key in source_state
                and source_state[source_key].size() == target_value.size()):
            new_target_state[target_key] = source_state[source_key]
        else:
            # Missing key *or* size mismatch: fall back to the value the
            # network already holds so load_state_dict gets a complete dict.
            new_target_state[target_key] = target_value
            print('[WARNING] Not found pre-trained parameters for {}'.format(target_key))

    net.load_state_dict(new_target_state)