def get_args():
parser = argparse.ArgumentParser(description=‘Train the UNet on images and target masks’)
parser.add_argument(‘–epochs’, ‘-e’, metavar=‘E’, type=int, default=300, help=‘Number of epochs’)
parser.add_argument(‘–batch-size’, ‘-b’, dest=‘batch_size’, metavar=‘B’, type=int, default=16, help=‘Batch size’)
parser.add_argument(‘–learning-rate’, ‘-l’, metavar=‘LR’, type=float, default=0.001,
help=‘Learning rate’, dest=‘lr’)
parser.add_argument(‘–load’, ‘-f’, type=str, default=False, help=‘Load model from a .pth file’)
parser.add_argument(‘–scale’, ‘-s’, type=float, default=0.5, help=‘Downscaling factor of the images’)
parser.add_argument(‘–validation’, ‘-v’, dest=‘val’, type=float, default=10.0,
help=‘Percent of the data that is used as validation (0-100)’)
parser.add_argument(‘–amp’, action=‘store_true’, default=False, help=‘Use mixed precision’)
return parser.parse_args()
epochs:epoch的个数,一般设置为300。
batch-size:批处理的大小,根据显存的大小设置。
learning-rate:学习率,一般设置为0.001,如果优化器不同,初始的学习率也要做相应的调整。
load:加载模型的路径,如果接着上次的训练,就需要设置上次训练的权重文件路径,如果有预训练权重,则设置预训练权重的路径。
scale:放大的倍数,这里设置为0.5,把图片大小变为原来的一半。
validation:验证验证集的百分比。
amp:是否使用混合精度?
比较重要的参数是epochs、batch-size和learning-rate,可以反复调整做实验,达到最好的精度。
接下来是设置模型:
net = UNet(n_channels=3, n_classes=2, bilinear=True)
logging.info(f’Network:\n’
f’\t{net.n_channels} input channels\n’
f’\t{net.n_classes} output channels (classes)\n’
f’\t{“Bilinear” if net.bilinear else “Transposed conv”} upscaling’)
if args.load:
net.load_state_dict(torch.load(args.load, map_location=device))
logging.info(f’Model loaded from {args.load}')
设置UNet参数,n_channels是imgs图片的通道数,如果是rgb则是3,如果是黑白图片就是1,n_classes设置为2,在这里把背景也当做一个类别,所以有两个类。
如果设置了权重文件,则加载权重文件,加载权重文件做迁移学习可以加快训练,减少迭代次数,所以如果有还是尽量加载预训练权重。
接下来修改train_net函数的逻辑。
try:
dataset = CarvanaDataset(dir_img, dir_mask, img_scale)
except (AssertionError, RuntimeError):
dataset = BasicDataset(dir_img, dir_mask, img_scale)
2. Split into train / validation partitions
n_val = int(len(dataset) * val_percent)
n_train = len(dataset) - n_val
train_set, v