【6PACK代码注解】train.py


前言

【6PACK全记录】6-PACK论文学习及复现记录
常见遍历epoch训练的模型:

model = MyModel()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4)

for epoch in range(1, epochs):
    for i, (inputs, labels) in enumerate(train_loader):
        output= model(inputs)
        loss = criterion(output, labels)
        
        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

train.py

一、超参数设置:

parser = argparse.ArgumentParser() #实例化参数解析器parser
#添加参数
parser.add_argument('--dataset_root', type=str, default = 'My_NOCS', help='dataset root dir')
parser.add_argument('--resume', type=str, default = '',  help='resume model') 
parser.add_argument('--category', type=int, default = 5,  help='category to train') 
parser.add_argument('--num_points', type=int, default = 500, help='points')
parser.add_argument('--num_cates', type=int, default = 6, help='number of categories')
parser.add_argument('--workers', type=int, default = 5, help='number of data loading workers')
parser.add_argument('--num_kp', type=int, default = 8, help='number of kp')
parser.add_argument('--outf', type=str, default = 'models/', help='save dir')
parser.add_argument('--lr', default=0.0001, help='learning rate')
opt = parser.parse_args() #解析后的参数

超参数含义:

-dataset_root:数据集根路径
-resume:此前保存的训练过的模型
-category:训练的物体类型
-num_points:点云中的点数
-num_cates:物体类别总数,我们的数据集中有6种
-workers:同时工作的线程数
-num_kp:每个实例需要提取的关键点数
-outf:模型保存的路径

二、模型加载:

若resume不为空则加载已经训练好的模型,否则使用network.py中的KeyNet进行训练

model = KeyNet(num_points = opt.num_points, num_key = opt.num_kp, num_cates = opt.num_cates)
model.cuda()

if opt.resume != '':
    model.load_state_dict(torch.load('{0}/{1}'.format(opt.outf, opt.resume)))#加载已经训练好的模型

三、数据集加载:

分别加载训练集(5000样本)和验证集(1000样本),其中训练集加入随机噪声干扰
两者均在1个batch中加载

dataset = Dataset('train', opt.dataset_root, True, opt.num_points, opt.num_cates, 5000, opt.category)#5000是训练集的样本数
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=opt.workers)
test_dataset = Dataset('val', opt.dataset_root, False, opt.num_points, opt.num_cates, 1000, opt.category)
testdataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=True, num_workers=opt.workers)

四、遍历500个epoch训练,每个epoch中先训练再验证:

4.1 训练

for i, data in enumerate(dataloader, 0):
        img_fr, choose_fr, cloud_fr, r_fr, t_fr, img_to, choose_to, cloud_to, r_to, t_to, mesh, anchor, scale, cate = data
        img_fr, choose_fr, cloud_fr, r_fr, t_fr, img_to, choose_to, cloud_to, r_to, t_to, mesh, anchor, scale, cate = Variable(img_fr).cuda(), \
                                                                                                                     Variable(choose_fr).cuda(), \
                                                                                                                     Variable(cloud_fr).cuda(), \
                                                                                                                     Variable(r_fr).cuda(), \
                                                                                                                     Variable(t_fr).cuda(), \
                                                                                                                     Variable(img_to).cuda(), \
                                                                                                                     Variable(choose_to).cuda(), \
                                                                                                                     Variable(cloud_to).cuda(), \
                                                                                                                     Variable(r_to).cuda(), \
                                                                                                                     Variable(t_to).cuda(), \
                                                                                                                     Variable(mesh).cuda(), \
                                                                                                                     Variable(anchor).cuda(), \
                                                                                                                     Variable(scale).cuda(), \
                                                                                                                     Variable(cate).cuda()

        Kp_fr, anc_fr, att_fr = model(img_fr, choose_fr, cloud_fr, anchor, scale, cate, t_fr) #调用keynet中的forward()训练,输出all_kp_x, output_anchor, att_x
        Kp_to, anc_to, att_to = model(img_to, choose_to, cloud_to, anchor, scale, cate, t_to)
        #kp:关键点坐标(实际,非归一化)
        #anc:锚点坐标(实际,非归一化)
        #att:锚点置信分数

        #loss.forward()
        loss, _ = criterion(Kp_fr, Kp_to, anc_fr, anc_to, att_fr, att_to, r_fr, t_fr, r_to, t_to, mesh, scale, cate)
        loss.backward()

        train_dis_avg += loss.item() 
        train_count += 1

        if train_count != 0 and train_count % 8 == 0: #每8个作为1个mini-batch
            optimizer.step() #根据当前梯度优化更新参数
            optimizer.zero_grad() #梯度清空
            print(train_count, float(train_dis_avg) / 8.0) 
            train_dis_avg = 0.0

        if train_count != 0 and train_count % 100 == 0: #每100个保存一次当前模型,如训练过程终端,可以从这里开始
            torch.save(model.state_dict(), '{0}/model_current_{1}.pth'.format(opt.outf, cate_list[opt.category-1]))

该过程中每8个样本作为一个mini-batch进行梯度下降和参数更新,并输出该epoch中当前训练的次数和平均误差。输出形式如下:
在这里插入图片描述
前8行是遍历dataloader时dataset_nocs.py中输出的:

print(tmp_cate_id, item)

之后为输出的训练次数和平均误差。
为防止训练过程中断而被迫重新开始的情况,每训练100个样本保存当前模型,如中断,可以通过“–resume”参数载入最新模型。

4.2 验证

optimizer.zero_grad()
    model.eval() #切换到eval模式,关闭 batch normalization 和 dropout
    score = []
    for j, data in enumerate(testdataloader, 0):
        img_fr, choose_fr, cloud_fr, r_fr, t_fr, img_to, choose_to, cloud_to, r_to, t_to, mesh, anchor, scale, cate = data
        img_fr, choose_fr, cloud_fr, r_fr, t_fr, img_to, choose_to, cloud_to, r_to, t_to, mesh, anchor, scale, cate = Variable(img_fr).cuda(), \
                                                                                                                     Variable(choose_fr).cuda(), \
                                                                                                                     Variable(cloud_fr).cuda(), \
                                                                                                                     Variable(r_fr).cuda(), \
                                                                                                                     Variable(t_fr).cuda(), \
                                                                                                                     Variable(img_to).cuda(), \
                                                                                                                     Variable(choose_to).cuda(), \
                                                                                                                     Variable(cloud_to).cuda(), \
                                                                                                                     Variable(r_to).cuda(), \
                                                                                                                     Variable(t_to).cuda(), \
                                                                                                                     Variable(mesh).cuda(), \
                                                                                                                     Variable(anchor).cuda(), \
                                                                                                                     Variable(scale).cuda(), \
                                                                                                                     Variable(cate).cuda()

        Kp_fr, anc_fr, att_fr = model(img_fr, choose_fr, cloud_fr, anchor, scale, cate, t_fr)
        Kp_to, anc_to, att_to = model(img_to, choose_to, cloud_to, anchor, scale, cate, t_to)

        _, item_score = criterion(Kp_fr, Kp_to, anc_fr, anc_to, att_fr, att_to, r_fr, t_fr, r_to, t_to, mesh, scale, cate)
        
        print(item_score)
        score.append(item_score)

遍历验证集,并输出每个样本的估计误差。若当前epoch验证集所有样本的误差平均值达到最小,保存该epoch训练出的模型。
输出形式如下:
在这里插入图片描述
每个样本对应两行输出,第一行为遍历dataloader时dataset_nocs.py中输出的:

print(tmp_cate_id, item)

第二行为该样本的估计误差。

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值