from torchvision import models
import torch.nn as nn
import torch
from resnet import resnet50_modifued
'''任务一,1.将resnet作为backbone,修改为模型自己需要的模型
2.将resnet的预训练参数加载到自己的模型
'''
# Task 1: instantiate the customised network and seed it with the parameters
# of an ImageNet-pretrained torchvision ResNet-50.
# NOTE: this rebinds the imported factory name to the model instance; the name
# is kept because it is what other code using this script refers to.
resnet50_modifued = resnet50_modifued()
new_weights_dict = resnet50_modifued.state_dict()

# models.resnet50() only accepts members of ResNet50_Weights; handing it a
# ResNeXt50_32X4D_Weights member is rejected by torchvision's weight check.
resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
weights_dict = resnet.state_dict()

for k, v in weights_dict.items():
    # The classifier head ("fc*") is never transferred.
    if k.startswith("fc"):
        continue
    # Copy only entries that exist in the customised model with an identical
    # shape; a same-named but reshaped layer keeps its own initialisation
    # instead of making load_state_dict() fail on a size mismatch.
    if k in new_weights_dict and new_weights_dict[k].shape == v.shape:
        new_weights_dict[k] = v

resnet50_modifued.load_state_dict(new_weights_dict)
'''任务二 冻结训练好的参数'''
# Task 2: freeze every parameter except those whose name starts with one of
# the prefixes in `train_layer`, and optimise only the unfrozen ones.
train_layer = ["layer5", "conv_end", "bn_end"]
params = []
for name, param in resnet50_modifued.named_parameters():
    # str.startswith accepts a tuple and matches if any prefix matches.
    if not name.startswith(tuple(train_layer)):
        # Not selected for training: exclude it from gradient computation.
        param.requires_grad = False
        continue
    # Selected for training: report its name and hand it to the optimizer.
    print(name)
    params.append(param)

# Plain SGD over the selected parameters only.
optimizer = torch.optim.SGD(params, lr=0.001)