# Main code: training script for ShapeNet part segmentation
import argparse
import os
import torch
import torch.nn.parallel
import torch.utils.data
from utils import to_categorical
from collections import defaultdict
from torch.autograd import Variable
from data_utils.ShapeNetDataLoader import PartNormalDataset
import torch.nn.functional as F
import datetime
import logging
from pathlib import Path
from utils import test_partseg
from tqdm import tqdm
from model.pointnet2 import PointNet2PartSeg_msg_one_hot
from model.pointnet import PointNetDenseCls,PointNetLoss
# Object category -> list of global part-label ids belonging to that category.
# The 16 categories together own the 50 part labels 0..49.
seg_classes = {
    'Earphone': [16, 17, 18],
    'Motorbike': [30, 31, 32, 33, 34, 35],
    'Rocket': [41, 42, 43],
    'Car': [8, 9, 10, 11],
    'Laptop': [28, 29],
    'Cap': [6, 7],
    'Skateboard': [44, 45, 46],
    'Mug': [36, 37],
    'Guitar': [19, 20, 21],
    'Bag': [4, 5],
    'Lamp': [24, 25, 26, 27],
    'Table': [47, 48, 49],
    'Airplane': [0, 1, 2, 3],
    'Pistol': [38, 39, 40],
    'Chair': [12, 13, 14, 15],
    'Knife': [22, 23],
}

# Inverse lookup: part-label id -> owning category name,
# e.g. {0: 'Airplane', 1: 'Airplane', ..., 49: 'Table'}.
seg_label_to_cat = {}
for cat, part_ids in seg_classes.items():
    for label in part_ids:
        seg_label_to_cat[label] = cat
def parse_args():
    """Parse the command-line options for part-segmentation training.

    Returns:
        argparse.Namespace holding the options declared below.
    """

    def _str2bool(value):
        # Convert a CLI string to a real bool. Without an explicit ``type``,
        # ``--jitter False`` used to produce the truthy string 'False'.
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
            return True
        if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
            return False
        raise argparse.ArgumentTypeError('expected a boolean value, got %r' % value)

    parser = argparse.ArgumentParser('PointNet2')
    parser.add_argument('--batchsize', type=int, default=8, help='input batch size')
    parser.add_argument('--workers', type=int, default=0, help='number of data loading workers')
    parser.add_argument('--epoch', type=int, default=4, help='number of epochs for training')
    parser.add_argument('--pretrain', type=str, default=None, help='path to a pretrained model checkpoint to load')
    parser.add_argument('--gpu', type=str, default='0', help='specify gpu device')
    parser.add_argument('--model_name', type=str, default='pointnet', help='Name of model')
    parser.add_argument('--learning_rate', type=float, default=0.001, help='learning rate for training')
    parser.add_argument('--decay_rate', type=float, default=1e-4, help='weight decay')
    # Only these two optimizers are handled by the training code; reject
    # anything else up front instead of failing later with an unbound name.
    parser.add_argument('--optimizer', type=str, default='Adam', choices=['SGD', 'Adam'], help='type of optimizer')
    parser.add_argument('--multi_gpu', type=str, default=None, help='whether use multi gpu training')
    # ``--jitter`` alone enables it; ``--jitter true/false`` is also accepted.
    parser.add_argument('--jitter', type=_str2bool, nargs='?', const=True, default=False, help="randomly jitter point cloud")
    parser.add_argument('--step_size', type=int, default=20, help="number of epochs between learning-rate decays")
    return parser.parse_args()
def main(args):
#创建文件夹
# os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu if args.multi_gpu is None else '0,1,2,3'
'''CREATE DIR'''
experiment_dir = Path('./experiment/')
experiment_dir.mkdir(exist_ok=True)
file_dir = Path(str(experiment_dir) +'/%sPartSeg-'%args.model_name + str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M')))
file_dir.mkdir(exist_ok=True)
checkpoints_dir = file_dir.joinpath('checkpoints/')
checkpoints_dir.mkdir(exist_ok=True)
log_dir = file_dir.joinpath('logs/')
log_dir.mkdir(exist_ok=True)
'''LOG'''
#使用logging
args = parse_args()
logger = logging.getLogger(args.model_name)#设置logger 记录器
logger.setLevel(logging.INFO)#设置等级
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')#设置输出的布局
file_handler = logging.FileHandler(str(log_dir) + '/train_%s_partseg.txt'%args.model_name)#设置handler处理器
file_handler.setLevel(logging.INFO)
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
logger.info('---------------------------------------------------TRANING---------------------------------------------------')
logger.info('PARAMETER ...')
logger.info(args)
norm = True if args.model_name == 'pointnet' else False
#数据集加载
TRAIN_DATASET = PartNormalDataset(npoints=2048, split='trainval',normalize=norm, jitter=args.jitter)
dataloader = torch.utils.data.DataLoader(TRAIN_DATASET, batch_size=args.batchsize,shuffle=True, num_workers=int(args.workers))
TEST_DATASET = PartNormalDataset(npoints=2048, split='test',normalize=norm,jitter=False)
testdataloader = torch.utils.data.DataLoader(TEST_DATASET, batch_size=8,shuffle=True, num_workers=int(args.workers))
print("The number of training data is:",len(TRAIN_DATASET))
logger.info("The number of training data is:%d",len(TRAIN_DATASET))
print("The number of test data is:", len(TEST_DATASET))
logger.info("The number of test data is:%d", len(TEST_DATASET))
num_classes = 16
num_part = 50
blue = lambda x: '\033[94m' + x + '\033[0m'
model = PointNet2PartSeg_msg_one_hot(num_part) if args.model_name == 'pointnet2'else PointNetDenseCls(cat_num=num_classes,part_num=num_part)
if args.pretrain is not None:
model.load_state_dict(torch.load(args.pretrain))
print('load model %s'%args.pretrain)
logger.info('load model %s'%args.pretrain)
else:
print('Training from scratch')
logger.info('Training from scratch')
pretrain = args.pretrain
init_epoch = int(pretrain[-14:-11]) if args.pretrain is not None else 0
if args.optimizer == 'SGD':
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
elif args.optimizer == 'Adam':
optimizer = torch.optim.Adam(
model.parameters(),
lr=args.learning_rate,
betas=(0.9, 0.999),
eps=1e-08,
weight_decay=args.decay_rate
)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=0.5)#调整学习率的方法,根据epoc