UNet语义分割实战:使用UNet实现对人物的抠图(1)

本文介绍了如何使用U-Net模型在Carvana数据集上进行图像分割训练,详细说明了关键参数如epochs、batch_size、learning_rate的设置,以及如何加载预训练模型、数据集划分、DataLoader使用和loss计算过程,包括DiceLoss的引入以优化模型性能。
摘要由CSDN通过智能技术生成

class CarvanaDataset(BasicDataset):

def init(self, images_dir, masks_dir, scale=1):

super().init(images_dir, masks_dir, scale, mask_suffix=‘_matte’)

将mask_suffix改为“_matte”

训练

=============================================================

打开train.py,先查看全局参数:

def get_args():

parser = argparse.ArgumentParser(description=‘Train the UNet on images and target masks’)

parser.add_argument(‘–epochs’, ‘-e’, metavar=‘E’, type=int, default=300, help=‘Number of epochs’)

parser.add_argument(‘–batch-size’, ‘-b’, dest=‘batch_size’, metavar=‘B’, type=int, default=16, help=‘Batch size’)

parser.add_argument(‘–learning-rate’, ‘-l’, metavar=‘LR’, type=float, default=0.001,

help=‘Learning rate’, dest=‘lr’)

parser.add_argument(‘–load’, ‘-f’, type=str, default=False, help=‘Load model from a .pth file’)

parser.add_argument(‘–scale’, ‘-s’, type=float, default=0.5, help=‘Downscaling factor of the images’)

parser.add_argument(‘–validation’, ‘-v’, dest=‘val’, type=float, default=10.0,

help=‘Percent of the data that is used as validation (0-100)’)

parser.add_argument(‘–amp’, action=‘store_true’, default=False, help=‘Use mixed precision’)

return parser.parse_args()

epochs:epoch的个数,一般设置为300。

batch-size:批处理的大小,根据显存的大小设置。

learning-rate:学习率,一般设置为0.001,如果优化器不同,初始的学习率也要做相应的调整。

load:加载模型的路径,如果接着上次的训练,就需要设置上次训练的权重文件路径,如果有预训练权重,则设置预训练权重的路径。

scale:放大的倍数,这里设置为0.5,把图片大小变为原来的一半。

validation:验证验证集的百分比。

amp:是否使用混合精度?

比较重要的参数是epochs、batch-size和learning-rate,可以反复调整做实验,达到最好的精度。

接下来是设置模型:

net = UNet(n_channels=3, n_classes=2, bilinear=True)

logging.info(f’Network:\n’

f’\t{net.n_channels} input channels\n’

f’\t{net.n_classes} output channels (classes)\n’

f’\t{“Bilinear” if net.bilinear else “Transposed conv”} upscaling’)

if args.load:

net.load_state_dict(torch.load(args.load, map_location=device))

logging.info(f’Model loaded from {args.load}')

设置UNet参数,n_channels是imgs图片的通道数,如果是rgb则是3,如果是黑白图片就是1,n_classes设置为2,在这里把背景也当做一个类别,

  • 8
    点赞
  • 20
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值