此博文是修改https://blog.csdn.net/jiacong_wang/article/details/105631229
这位大大的博文而成的,自己根据自己的情况稍微加了点东西
要修改的地方有4处
1.修改网络第一层,把3通道改为1
法一:直接在定义网络的地方修改
self.conv1 = nn.Conv2d(1, self.in_channel, kernel_size=7, stride=2,padding=3, bias=False)
法二:在调用网络模型的地方修改
model = resnet50()
model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
model = model.to(device)
2.修改读取数据的方式
--------修改之前:
--------修改之后
train_transformer = transforms.Compose([
transforms.Grayscale(1), # 修改
transforms.RandomHorizontalFlip(0.5),
transforms.ToTensor(),
transforms.Normalize(0.485, 0.229, inplace=True), # 修改(-1,1)
])
- 修改方法
- 修改transform(图像预处理操作)
添加transforms.Grayscale(1),将图像转换为单通道图像(经实验,图像矩阵的数据并不会发生变化) - transforms.Normalize修改如下,第一个参数为mean,第二个参数为std,因为是单通道,所以进行Z-Score时仅需要对一个通道进行操作,所以mean和std只需要一个值就行
3.修改读取数据集部分(mydataset.py)
-----修改之前
-----修改之后
img = Image.open(path_img)
只需要把后面转RGB的部分去掉就行
4.因为加载预训练模型而修改网络
要加载预训练模型,第一层的权重参数肯定不能加载,则只需要把第一层的权重参数避开就行:
- 加载方法
net = resnet50()
# load pretrain weights
# download url: https://download.pytorch.org/models/resnet34-333f7ec4.pth
model_weight_path = "./resnet50-19c8e357.pth"
assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)
# 加载预训练模型并且把不需要的层去掉
pre_state_dict = torch.load(model_weight_path)
print("原模型", pre_state_dict.keys())
new_state_dict = {}
for k, v in net.state_dict().items(): # 遍历修改模型的各个层
print("新模型", k)
if k in pre_state_dict.keys() and k!= 'conv1.weight':
new_state_dict[k] = pre_state_dict[k] # 如果原模型的层也在新模型的层里面, 那新模型就加载原先训练好的权重
net.load_state_dict(new_state_dict, False)
# for param in net.parameters():
# param.requires_grad = False
# change fc layer structure
in_channel = net.fc.in_features
net.fc = nn.Linear(in_channel, 5)
net.to(device)