def load_state_dict(self, state_dict):
    """Copy matching entries of ``state_dict`` into ``self._channel_dict``.

    Hand-written replacement for the usual ``load_state_dict``. Only keys
    that already exist in ``self._channel_dict`` are loaded. Every other key
    in ``state_dict`` is skipped silently, which allows partial loading of a
    checkpoint. Entries whose name contains ``'bn'`` get a leading axis of
    size 1 (``unsqueeze(0)``) before being copied.

    Args:
        state_dict: mapping of parameter name -> tensor (presumably the output
            of ``nn.Module.state_dict()`` -- TODO confirm against callers).

    Raises:
        RuntimeError: if copying an entry into its target tensor fails
            (e.g. a shape mismatch). The underlying exception is chained as
            ``__cause__``.
    """
    import torch  # local import keeps this block self-contained

    for name, param in state_dict.items():
        if name not in self._channel_dict:
            continue  # keys the model does not know about are ignored
        target = self._channel_dict[name]
        # NOTE(review): 'bn' targets presumably carry an extra leading axis
        # in _channel_dict -- confirm against where _channel_dict is built.
        source = param.unsqueeze(0).data if 'bn' in name else param.data
        try:
            # no_grad: autograd rejects an in-place copy into a leaf tensor
            # that requires grad (e.g. an nn.Parameter).
            with torch.no_grad():
                target.copy_(source)
        except Exception as err:
            # Report the checkpoint tensor's own size (before any unsqueeze),
            # as the message promises, and keep the original error as cause.
            raise RuntimeError('While copying the parameter named {}, '
                               'whose dimensions in the model are {} and '
                               'whose dimensions in the checkpoint are {}.'
                               .format(name, target.size(), param.size())) from err
# Reference: https://blog.csdn.net/a137376864/article/details/78654618