img_list格式如下
E:\...\3.nrrd E:\...\3.nrrd 0
E:\...\4.nrrd E:\...\4.nrrd 1
训练代码
import torch
from torch import nn
import os
import numpy as np
from torch.utils.data import Dataset
from scipy import ndimage
from torch import optim
from torch.utils.data import DataLoader
import time
import logging
import nrrd
# ---------------------------------------------------------------------------
# Run configuration (plain module-level settings).
# ---------------------------------------------------------------------------

# -- Paths --
# Path of the image list file (str).
img_list = 'data/train.txt'
# Path of the pretrained model (str).
pretrain_path = 'pretrain/resnet_50.pth'
# Output folder name; presumably where checkpoints are written -- confirm
# against the saving code.
save_folder = "trails/models/Resnet50"

# -- Schedule --
# Number of total epochs to run (int).
total_epochs = 20
# Iteration interval for saving the model (int).
save_intervals = 10
# Initial learning rate (float); use 0.001 when fine-tuning. Per the original
# note it is divided by 10 during training by the LR scheduler.
learning_rate = 1e-3
# Batch size (int).
batch_size = 1

# -- Model --
# Names of the newly added layers, i.e. everything except the backbone (list).
new_layer_names = ['conv_cls']
# Input volume size as (depth, height, width), all ints.
input_D, input_H, input_W = 56, 448, 448

# Fix the PyTorch RNG seed.
torch.manual_seed(1)
class Bottleneck(nn.Module):
    """3D residual bottleneck block: 1x1x1 -> 3x3x3 -> 1x1x1 convolutions.

    The first convolution maps ``inplanes`` channels to ``planes``, the
    second keeps ``planes`` channels while applying ``stride`` and
    ``dilation``, and the third widens to ``planes * expansion`` channels.
    The (optionally projected) input is added back before the final ReLU.

    Args:
        inplanes: Number of input channels.
        planes: Number of channels inside the bottleneck.
        stride: Stride of the 3x3x3 convolution.
        dilation: Dilation of the 3x3x3 convolution; also used as its
            padding, so the spatial size is preserved when ``stride == 1``.
        downsample: Optional module applied to the block input to produce
            the shortcut branch. Needed whenever the main branch changes the
            channel count or spatial size, otherwise the addition fails.
    """

    # Ratio between the block's output channels and ``planes``.
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None):
        super(Bottleneck, self).__init__()
        out_planes = planes * self.expansion

        self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm3d(planes)
        self.conv2 = nn.Conv3d(
            planes,
            planes,
            kernel_size=3,
            stride=stride,
            dilation=dilation,
            padding=dilation,
            bias=False,
        )
        self.bn2 = nn.BatchNorm3d(planes)
        self.conv3 = nn.Conv3d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm3d(out_planes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the block on ``x`` and return the activated residual sum."""
        # Shortcut branch: identity unless a projection module was supplied.
        residual = x if self.downsample is None else self.downsample(x)

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        # In-place add on the freshly produced bn3 output; ``x`` is untouched.
        out += residual
        return self.relu(out)
class ResNet(nn.Module):
def __init__(self, block, layers, input_D, input_H, input_W):
    """Build the 3D ResNet backbone and the classification head.

    Args:
        block: Residual block class. It must expose an ``expansion``
            attribute and accept ``(inplanes, planes, stride=, dilation=,
            downsample=)``.
        layers: Sequence of four ints giving the number of blocks in
            ``layer1`` .. ``layer4``.
        input_D: Input depth. Accepted for interface compatibility; not
            used in this constructor.
        input_H: Input height. Accepted but not used in this constructor.
        input_W: Input width. Accepted but not used in this constructor.
    """
    super(ResNet, self).__init__()
    # Running channel count consumed and updated by ``_make_layer``.
    self.inplanes = 64

    # Stem: single-channel input, 7x7x7 convolution with stride 2 in every
    # spatial dimension, followed by a stride-2 max pool.
    self.conv1 = nn.Conv3d(1, 64, kernel_size=7, stride=(2, 2, 2), padding=(3, 3, 3), bias=False)
    self.bn1 = nn.BatchNorm3d(64)  # matches conv1's output channels
    self.relu = nn.ReLU(inplace=True)
    self.maxpool = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=2, padding=1)

    # Residual stages. Each stage outputs ``planes * block.expansion``
    # channels. Only layer2 is strided; layer3 and layer4 are built with
    # stride 1 and use dilation (2 and 4) instead.
    self.layer1 = self._make_layer(block, 64, layers[0])
    self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
    self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2)
    self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4)

    # Head: global max pool -> flatten -> dropout -> one output unit.
    self.conv_cls = nn.Sequential(
        nn.AdaptiveMaxPool3d(output_size=(1, 1, 1)),
        nn.Flatten(start_dim=1),
        nn.Dropout(0.1),
        nn.Linear(512 * block.expansion, 1),
    )

    # Weight initialisation for every 3D convolution and batch-norm layer.
    # The init functions work in place, so no reassignment is needed.
    for m in self.modules():
        if isinstance(m, nn.Conv3d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out')
        elif isinstance(m, nn.BatchNorm3d):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)
def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
    """Assemble one residual stage as an ``nn.Sequential`` of ``block``s.

    Args:
        block: Residual block class exposing an ``expansion`` attribute.
        planes: Bottleneck width; the stage outputs
            ``planes * block.expansion`` channels.
        blocks: Number of blocks in the stage.
        stride: Stride applied by the first block only.
        dilation: Dilation applied by every block in the stage.

    Returns:
        The stage as an ``nn.Sequential``.

    Side effects:
        Updates ``self.inplanes`` to the stage's output channel count so
        the next call starts from the right width.
    """
    out_planes = planes * block.expansion

    # A projection shortcut is needed whenever the first block changes the
    # spatial size (stride) or the channel count.
    downsample = None
    if stride != 1 or self.inplanes != out_planes:
        downsample = nn.Sequential(
            nn.Conv3d(self.inplanes, out_planes, kernel_size=1, stride=stride, bias=False),
            nn.BatchNorm3d(out_planes),
        )

    layers = [block(self.inplanes, planes, stride=stride, dilation=dilation, downsample=downsample)]
    # From here on (including later _make_layer calls) the input width is
    # the expanded channel count.
    self.inplanes = out_planes
    layers.extend(block(self.inplanes, planes, dilation=dilation) for _ in range(1, blocks))
    return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.conv_cls(x)
x = torch.sigmoid_(x)