相对应pdf也已上传
main函数
if config['name'] is None:
设置网络设备的状态,或是显示目前的设置。
将所有参数放到config中
print('%s: %s' % (key, config[key]))
将所有参数打印出来
打开models/%s/config.yml
if config['loss'] == 'BCEWithLogitsLoss':
定义损失函数
cudnn.benchmark = True
让内置的 cuDNN 的 auto-tuner 自动寻找最适合当前配置的高效算法,来达到优化运行效率的问题。
model = archs.__dict__[config['arch']](config['num_classes'],config['input_channels'],config['deep_supervision'])
创建模型
进入vgg基本计算单元
先relu层→卷积层(3个)→batch(middle_channels=32)→卷积层→batch(out_channels=32)
self.conv0_0 = VGGBlock(input_channels, nb_filter[0], nb_filter[0])
计算各卷积层数值
0_0(32,32,32)
1_0(32,64,64)
2_0(64,128,128)
3_0(128,256,256)
4_0(256,512,512)
0_1(96,32,32)
1_1(192,64,64)
2_1(384,128,128)
3_1(768,256,256)
0_2(128,32,32)
1_2(256,64,64)
2_2(512,128,128)
0_3(160,32,32)
1_3(320,64,64)
0_4(192,32,32)
if self.deep_supervision:
self.final1 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
全监督层
若无则为final
elif config['optimizer'] == 'SGD':
使用sgd优化器
params = filter(lambda p: p.requires_grad, model.parameters())
过滤序列,过滤掉不符合条件的元素,返回由符合条件元素组成的新列表。
filter(function, iterable)
参数
function -- 判断函数。
iterable -- 可迭代对象。
返回值
返回列表。
A.Resize(config['input_h'], config['input_w']),
将输入图按规定长宽裁减
train_transform = Compose([
训练集,数据增强操作
随机角度旋转
图像翻转
执行机率
A.RandomRotate90(),
transforms.Flip(),
OneOf([
transforms.HueSaturationValue(),
transforms.RandomBrightness(),
transforms.RandomContrast(),
], p=1)
A.Resize(config['input_h'], config['input_w']),
transforms.Normalize(),
val_transform = Compose([
验证集
A.Resize(config['input_h'], config['input_w']),
transforms.Normalize(),
train_dataset = Dataset(
训练集数据设置
img_ids=train_img_ids,
img_dir=os.path.join('inputs', config['dataset'], 'images'),
mask_dir=os.path.join('inputs', config['dataset'], 'masks'),
img_ext=config['img_ext'],
mask_ext=config['mask_ext'],
num_classes=config['num_classes'],
transform=train_transform)
进入dataset.py函数
val_dataset = Dataset(
验证集数据设置
img_ids=val_img_ids,
img_dir=os.path.join('inputs', config['dataset'], 'images'),
mask_dir=os.path.join('inputs', config['dataset'], 'masks'),
img_ext=config['img_ext'],
mask_ext=config['mask_ext'],
num_classes=config['num_classes'],
transform=val_transform)
进入dataset.py函数运行
self.img_ids = img_ids
self.img_dir = img_dir
self.mask_dir = mask_dir
self.img_ext = img_ext
self.mask_ext = mask_ext
self.num_classes = num_classes
self.transform = transform
train_loader = torch.utils.data.DataLoader(
log = OrderedDict([
日志
best_iou = 0
trigger = 0
epoch
parse_args()函数
parser = argparse.ArgumentParser()
创建解释器
使用 argparse 的第一步是创建一个 ArgumentParser 对象。
ArgumentParser 对象包含将命令行解析成 Python 数据类型所需的全部信息。
parser.add_argument( )
添加参数
给一个 ArgumentParser 添加程序参数信息是通过调用 add_argument() 方法完成的。
网络名称=arch+timestamp
epochs=100
--batch_size=8
--arch网络架构-=NestedUNet
深度监督--deep_supervision=false
输入渠道--input_channels=3
数量类别 number of classes=1
图像宽度 image width=96
图像高度 image height=96
损失函数loss = BCEDiceLoss
数据设置--dataset=dsb2018_96
外部文件图像--img_ext=.png
外部文件设置--mask_ext=.png
优化器--optimizer=SGD
学习率--learning_rate=1e-3
动量--momentum=0.9
宽度腐蚀量--weight_decay=1e-4
--nesterov=false
调度程序--scheduler=CosineAnneaLingLR
最小学习率--min_lr=1e-5
倍数--factor=0.1
容忍度--patience=2
重要事件--milestones=1,2
--gamma=2/3
提前停止--early_stopping=-1
网络数量--num_workers=0
config = parser.parse_args()
解析参数
ArgumentParser 通过 parse_args() 方法解析参数。
train()函数
avg_meters = {'loss': AverageMeter(),
'iou': AverageMeter()}
计算epoch平均损失
model.train()
进入model函数
input, target
使用前向传播预测值
二分类计算损失值
计算iou得分
梯度清0
反向传播
记录损失值和iou
返回loss和iou的平均损失
进入validate()函数
validate()函数
avg_meters = {'loss': AverageMeter(),'iou': AverageMeter()}
计算epoch平均损失
model.eval()
#切换评估模型
if config['deep_supervision']:outputs = model(input)
全监督 输出层loss=0
avg_meters