def init_weights(self, checkpoint_path='F:/Code/pytorch-deeplab-dualattention/test/state_dict_73.98.pth'):
    """Load a pretrained checkpoint into this model after renaming its keys.

    The checkpoint entries are renamed according to their position in the
    checkpoint (see ``rename_ranges`` below), then reconciled against
    ``self.state_dict()``:

    * renamed entries the model does not have are dropped;
    * entries whose shape differs from the model's are replaced by the
      model's current tensor;
    * model entries missing from the checkpoint keep the model's current
      tensor.

    The result is re-ordered to follow the model's own key order and loaded
    with ``strict=True``.

    Args:
        checkpoint_path: Path of the checkpoint passed to ``torch.load``.
            The file is expected to hold a mapping from parameter name to
            tensor. Defaults to the path that was previously hard-coded.
    """
    # (start, stop, prefix): checkpoint entries whose position falls in
    # [start, stop) are stored under ``prefix + key[10:]``. A prefix of None
    # keeps the original key. Entries at position 502 and beyond are ignored.
    rename_ranges = (
        (0, 6, None),
        (6, 30, "layer1.0"),
        (30, 72, "layer1.1"),
        (72, 96, "layer1.2"),
        (96, 142, "layer2.0"),
        (142, 170, "layer2.1"),
        (170, 212, "layer3.0"),
        (212, 236, "layer3.1"),
        (236, 260, "layer3.2"),
        (260, 284, "layer3.3"),
        (284, 324, "layer3.4"),
        (324, 352, "layer3.5"),
        (352, 398, "layer4.0"),
        (398, 422, "layer4.1"),
        (422, 450, "layer4.2"),
        (450, 474, "layer4.3"),
        (474, 502, "layer4.4"),
    )

    # Load onto the CPU so a checkpoint saved on a GPU can be read on any
    # host; load_state_dict copies the values into the model's own tensors.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    # Rename the pretrained keys.
    # NOTE(review): keys inside one range that differ only within their first
    # 10 characters collapse to the same new name, and the later entry
    # overwrites the earlier one - confirm against the actual checkpoint.
    state_dict = OrderedDict()
    for index, key in enumerate(checkpoint):
        for start, stop, prefix in rename_ranges:
            if start <= index < stop:
                new_key = key if prefix is None else prefix + key[10:]
                state_dict[new_key] = checkpoint[key]
                break

    model_state_dict = self.state_dict()

    # Check the renamed weights against the model and remove unusable ones.
    # Iterate over a snapshot of the keys: entries are deleted inside the
    # loop, and mutating an OrderedDict while iterating it raises
    # RuntimeError.
    for key in list(state_dict):
        if key not in model_state_dict:
            del state_dict[key]
            print('Drop parameter {}.'.format(key))
        elif state_dict[key].shape != model_state_dict[key].shape:
            print('Skip loading parameter {}, required shape{}, loaded shape{}.'.format(
                key, model_state_dict[key].shape, state_dict[key].shape))
            state_dict[key] = model_state_dict[key]

    # Fall back to the model's current tensor for anything the checkpoint
    # did not provide.
    for key in model_state_dict:
        if key not in state_dict:
            print('No param {}.'.format(key))
            state_dict[key] = model_state_dict[key]

    # Both dicts now hold exactly the same keys, so pair tensors by NAME in
    # the model's order. Pairing by position would attach tensors to the
    # wrong parameters whenever the two orders differ.
    pre_state_dict = OrderedDict(
        (key, state_dict[key]) for key in model_state_dict
    )
    self.load_state_dict(pre_state_dict, strict=True)
语义分割 DeepLabv3+ 更改预训练权重(GhostNet)
最新推荐文章于 2023-03-10 17:45:12 发布