代码:
from data_utils.ModelNetDataLoader import ModelNetDataLoader import argparse import numpy as np import os import torch import datetime import logging from pathlib import Path from tqdm import tqdm import sys import provider import importlib import shutil BASE_DIR = os.path.dirname(os.path.abspath(__file__)) ROOT_DIR = BASE_DIR sys.path.append(os.path.join(ROOT_DIR, 'models')) """ 需要配置的参数: --model pointnet2_cls_msg --normal --log_dir pointnet2_cls_msg """ def parse_args(): '''PARAMETERS''' parser = argparse.ArgumentParser('PointNet') parser.add_argument('--batch_size', type=int, default=8, help='batch size in training [default: 24]') parser.add_argument('--model', default='pointnet2_cls_msg', help='model name [default: pointnet_cls]') parser.add_argument('--epoch', default=200, type=int, help='number of epoch in training [default: 200]') parser.add_argument('--learning_rate', default=0.001, type=float, help='learning rate in training [default: 0.001]') parser.add_argument('--gpu', type=str, default='0', help='specify gpu device [default: 0]') parser.add_argument('--num_point', type=int, default=1024, help='Point Number [default: 1024]') parser.add_argument('--optimizer', type=str, default='Adam', help='optimizer for training [default: Adam]') parser.add_argument('--log_dir', type=str, default=None, help='experiment root') parser.add_argument('--decay_rate', type=float, default=1e-4, help='decay rate [default: 1e-4]') parser.add_argument('--normal', action='store_true', default=False, help='Whether to use normal information [default: False]') return parser.parse_args() def test(model, loader, num_class=40): mean_correct = [] class_acc = np.zeros((num_class,3)) for j, data in tqdm(enumerate(loader), total=len(loader)): points, target = data target = target[:, 0] points = points.transpose(2, 1) points, target = points.cuda(), target.cuda() classifier = model.eval() pred, _ = classifier(points) pred_choice = pred.data.max(1)[1] for cat in np.unique(target.cpu()): classacc = pred_choice[target==cat].eq(target[target==cat].long().data).cpu().sum() class_acc[cat,0]+= classacc.item()/float(points[target==cat].size()[0]) class_acc[cat,1]+=1 correct = pred_choice.eq(target.long().data).cpu().sum() mean_correct.append(correct.item()/float(points.size()[0])) class_acc[:,2] = class_acc[:,0]/ class_acc[:,1] class_acc = np.mean(class_acc[:,2]) instance_acc = np.mean(mean_correct) return instance_acc, class_acc def main(args): def log_string(str): logger.info(str) print(str) '''HYPER PARAMETER''' os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu '''CREATE DIR''' timestr = str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M')) experiment_dir = Path('./log/') experiment_dir.mkdir(exist_ok=True) experiment_dir = experiment_dir.joinpath('classification') experiment_dir.mkdir(exist_ok=True) if args.log_dir is None: experiment_dir = experiment_dir.joinpath(timestr) else: experiment_dir = experiment_dir.joinpath(args.log_dir) experiment_dir.mkdir(exist_ok=True) checkpoints_dir = experiment_dir.joinpath('checkpoints/') checkpoints_dir.mkdir(exist_ok=True) log_dir = experiment_dir.joinpath('logs/') log_dir.mkdir(exist_ok=True) '''LOG''' args = parse_args() logger = logging.getLogger("Model") logger.setLevel(logging.INFO) formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') file_handler = logging.FileHandler('%s/%s.txt' % (log_dir, args.model)) file_handler.setLevel(logging.INFO) file_handler.setFormatter(formatter) logger.addHandler(file_handler) log_string('PARAMETER ...') log_string(args) '''DATA LOADING''' log_string('Load dataset ...') DATA_PATH = 'data/modelnet40_normal_resampled/' TRAIN_DATASET = ModelNetDataLoader(root=DATA_PATH, npoint=args.num_point, split='train', normal_channel=args.normal) TEST_DATASET = ModelNetDataLoader(root=DATA_PATH, npoint=args.num_point, split='test', normal_channel=args.normal) trainDataLoader = torch.utils.data.DataLoader(TRAIN_DATASET, batch_size=args.batch_size, shuffle=True, num_workers=4) testDataLoader = torch.utils.data.DataLoader(TEST_DATASET, batch_size=args.batch_size, shuffle=False, num_workers=4) '''MODEL LOADING''' num_class = 40 MODEL = importlib.import_module(args.model) shutil.copy('./models/%s.py' % args.model, str(experiment_dir)) shutil.copy('./models/pointnet_util.py', str(experiment_dir)) classifier = MODEL.get_model(num_class,normal_channel=args.normal).cuda() criterion = MODEL.get_loss().cuda() try: checkpoint = torch.load(str(experiment_dir) + '/checkpoints/best_model.pth') start_epoch = checkpoint['epoch'] classifier.load_state_dict(checkpoint['model_state_dict']) log_string('Use pretrain model') except: log_string('No existing model, starting training from scratch...') start_epoch = 0 if args.optimizer == 'Adam': optimizer = torch.optim.Adam( classifier.parameters(), lr=args.learning_rate, betas=(0.9, 0.999), eps=1e-08, weight_decay=args.decay_rate ) else: optimizer = torch.optim.SGD(classifier.parameters(), lr=0.01, momentum=0.9) scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.7) global_epoch = 0 global_step = 0 best_instance_acc = 0.0 best_class_acc = 0.0 mean_correct = [] '''TRANING''' logger.info('Start training...') for epoch in range(start_epoch,args.epoch): log_string('Epoch %d (%d/%s):' % (global_epoch + 1, epoch + 1, args.epoch)) scheduler.step() for batch_id, data in tqdm(enumerate(trainDataLoader, 0), total=len(trainDataLoader), smoothing=0.9): points, target = data points = points.data.numpy() points = provider.random_point_dropout(points) points[:,:, 0:3] = provider.random_scale_point_cloud(points[:,:, 0:3]) points[:,:, 0:3] = provider.shift_point_cloud(points[:,:, 0:3]) points = torch.Tensor(points) target = target[:, 0] points = points.transpose(2, 1) points, target = points.cuda(), target.cuda() optimizer.zero_grad() classifier = classifier.train() pred, trans_feat = classifier(points) loss = criterion(pred, target.long(), trans_feat) pred_choice = pred.data.max(1)[1] correct = pred_choice.eq(target.long().data).cpu().sum() mean_correct.append(correct.item() / float(points.size()[0])) loss.backward() optimizer.step() global_step += 1 train_instance_acc = np.mean(mean_correct) log_string('Train Instance Accuracy: %f' % train_instance_acc) with torch.no_grad(): instance_acc, class_acc = test(classifier.eval(), testDataLoader) if (instance_acc >= best_instance_acc): best_instance_acc = instance_acc best_epoch = epoch + 1 if (class_acc >= best_class_acc): best_class_acc = class_acc log_string('Test Instance Accuracy: %f, Class Accuracy: %f'% (instance_acc, class_acc)) log_string('Best Instance Accuracy: %f, Class Accuracy: %f'% (best_instance_acc, best_class_acc)) if (instance_acc >= best_instance_acc): logger.info('Save model...') savepath = str(checkpoints_dir) + '/best_model.pth' log_string('Saving at %s'% savepath) state = { 'epoch': best_epoch, 'instance_acc': instance_acc, 'class_acc': class_acc, 'model_state_dict': classifier.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), } torch.save(state, savepath) global_epoch += 1 logger.info('End of training...') if __name__ == '__main__': args = parse_args() main(args)
代码解释:
这段代码是一个用于训练和测试模型的主要脚本。
首先导入所需的库和模块,包括数据加载器、命令行参数解析、numpy、操作系统、torch、datetime、logging、Path和tqdm等。
然后定义了两个全局变量,BASE_DIR和ROOT_DIR,以及将ROOT_DIR添加到sys.path中。
接下来定义了一个函数parse_args(),用于解析命令行参数并返回参数值。
定义了一个函数test(model, loader, num_class),用于对模型进行测试。函数采用mini-batch方式处理数据,计算模型在每个测试批次上的准确率,并计算每个类别的平均准确率。
主函数main(args)是整个脚本的核心逻辑。在函数中,首先定义了一个内部函数log_string(str),用于记录日志信息和打印信息。
然后根据args.gpu参数设置环境变量CUDA_VISIBLE_DEVICES。
接着创建实验目录,并定义日志器logger,将日志写入文件中。
解析命令行参数,打印参数信息。
加载数据集,包括训练数据集和测试数据集。
加载模型,根据args.normal参数调用合适的模型构造函数。同时将模型和损失函数移动到GPU上。
根据args.optimizer参数选择合适的优化器。
开始训练循环。在每个epoch上,遍历训练数据集,进行模型训练。其中,模型根据points和trans_feat计算预测结果,然后计算损失值,进行反向传播和参数更新。同时计算准确率。
在训练过程中,定期进行测试,计算模型在测试数据集上的实例准确率和类别准确率。根据最佳的实例准确率和类别准确率保存最佳模型。
最后,在main函数中解析命令行参数,调用main函数开始训练和测试。
总的来说,这段代码提供了一个训练和测试模型的框架。您可以根据自己的需求和具体的模型进行适当修改和定制。