PointNet代码详解:部件分割(Part Segmentation)部分(有注解)

主代码

import argparse
import os
import torch
import torch.nn.parallel
import torch.utils.data
from utils import to_categorical
from collections import defaultdict
from torch.autograd import Variable
from data_utils.ShapeNetDataLoader import PartNormalDataset
import torch.nn.functional as F
import datetime
import logging
from pathlib import Path
from utils import test_partseg
from tqdm import tqdm
from model.pointnet2 import PointNet2PartSeg_msg_one_hot
from model.pointnet import PointNetDenseCls,PointNetLoss

# Object categories of the part-segmentation dataset (loaded via ShapeNetDataLoader)
# mapped to the global part-label ids that belong to each category.
# 16 categories in total; together they cover part labels 0-49.
seg_classes = {
    'Earphone': [16, 17, 18],
    'Motorbike': [30, 31, 32, 33, 34, 35],
    'Rocket': [41, 42, 43],
    'Car': [8, 9, 10, 11],
    'Laptop': [28, 29],
    'Cap': [6, 7],
    'Skateboard': [44, 45, 46],
    'Mug': [36, 37],
    'Guitar': [19, 20, 21],
    'Bag': [4, 5],
    'Lamp': [24, 25, 26, 27],
    'Table': [47, 48, 49],
    'Airplane': [0, 1, 2, 3],
    'Pistol': [38, 39, 40],
    'Chair': [12, 13, 14, 15],
    'Knife': [22, 23],
}

# Reverse lookup: part label -> category name, e.g. {0: 'Airplane', ..., 49: 'Table'}.
seg_label_to_cat = {}
for cat, part_labels in seg_classes.items():
    for label in part_labels:
        seg_label_to_cat[label] = cat




def parse_args():
    parser = argparse.ArgumentParser('PointNet2')
    parser.add_argument('--batchsize', type=int, default=8, help='input batch size')
    parser.add_argument('--workers', type=int, default=0, help='number of data loading workers')
    parser.add_argument('--epoch', type=int, default=4, help='number of epochs for training')
    parser.add_argument('--pretrain', type=str, default=None,help='whether use pretrain model')
    parser.add_argument('--gpu', type=str, default='0', help='specify gpu device')
    parser.add_argument('--model_name', type=str, default='pointnet', help='Name of model')
    parser.add_argument('--learning_rate', type=float, default=0.001, help='learning rate for training')
    parser.add_argument('--decay_rate', type=float, default=1e-4, help='weight decay')
    parser.add_argument('--optimizer', type=str, default='Adam', help='type of optimizer')
    parser.add_argument('--multi_gpu', type=str, default=None, help='whether use multi gpu training')
    parser.add_argument('--jitter', default=False, help="randomly jitter point cloud")
    parser.add_argument('--step_size', type=int, default=20, help="randomly rotate point cloud")

    return parser.parse_args()

def main(args):
    #创建文件夹
    # os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu if args.multi_gpu is None else '0,1,2,3'
    '''CREATE DIR'''
    experiment_dir = Path('./experiment/')
    experiment_dir.mkdir(exist_ok=True)

    file_dir = Path(str(experiment_dir) +'/%sPartSeg-'%args.model_name + str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M')))
    file_dir.mkdir(exist_ok=True)

    checkpoints_dir = file_dir.joinpath('checkpoints/')
    checkpoints_dir.mkdir(exist_ok=True)

    log_dir = file_dir.joinpath('logs/')
    log_dir.mkdir(exist_ok=True)

    '''LOG'''
    #使用logging
    args = parse_args()
    logger = logging.getLogger(args.model_name)#设置logger 记录器
    logger.setLevel(logging.INFO)#设置等级
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')#设置输出的布局

    file_handler = logging.FileHandler(str(log_dir) + '/train_%s_partseg.txt'%args.model_name)#设置handler处理器
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    logger.info('---------------------------------------------------TRANING---------------------------------------------------')
    logger.info('PARAMETER ...')
    logger.info(args)
    norm = True if args.model_name == 'pointnet' else False

    #数据集加载
    TRAIN_DATASET = PartNormalDataset(npoints=2048, split='trainval',normalize=norm, jitter=args.jitter)
    dataloader = torch.utils.data.DataLoader(TRAIN_DATASET, batch_size=args.batchsize,shuffle=True, num_workers=int(args.workers))
    TEST_DATASET = PartNormalDataset(npoints=2048, split='test',normalize=norm,jitter=False)
    testdataloader = torch.utils.data.DataLoader(TEST_DATASET, batch_size=8,shuffle=True, num_workers=int(args.workers))
    print("The number of training data is:",len(TRAIN_DATASET))
    logger.info("The number of training data is:%d",len(TRAIN_DATASET))
    print("The number of test data is:", len(TEST_DATASET))
    logger.info("The number of test data is:%d", len(TEST_DATASET))
    num_classes = 16
    num_part = 50
    blue = lambda x: '\033[94m' + x + '\033[0m'
    model = PointNet2PartSeg_msg_one_hot(num_part) if args.model_name == 'pointnet2'else PointNetDenseCls(cat_num=num_classes,part_num=num_part)

    if args.pretrain is not None:
        model.load_state_dict(torch.load(args.pretrain))
        print('load model %s'%args.pretrain)
        logger.info('load model %s'%args.pretrain)
    else:
        print('Training from scratch')
        logger.info('Training from scratch')
    pretrain = args.pretrain
    init_epoch = int(pretrain[-14:-11]) if args.pretrain is not None else 0


    if args.optimizer == 'SGD':
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    elif args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(
            model.parameters(),
            lr=args.learning_rate,
            betas=(0.9, 0.999),
            eps=1e-08,
            weight_decay=args.decay_rate
        )
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=0.5)#调整学习率的方法,根据epoc
(注:原文代码在此处被截断。`main()` 中创建学习率调度器之后的训练/测试循环未包含在本副本中;原网页的点赞、收藏、评论及红包等界面文字已省略。)