论文复现教程-(3DGNN)3D Graph Neural Networks for RGBD Semantic Segmentation

一. 下载源程序

源程序github地址:https://github.com/yanx27/3DGNN_pytorch

下载源程序并解压
在这里插入图片描述

二. 准备数据集

数据集地址:https://github.com/yanx27/3DGNN_pytorch

下载其中的标签数据集 NYU_Depth_V2,共计2.8G
在这里插入图片描述
github下载速度较慢,可以使用 数据集百度云下载地址, 密码: bfi4

存储路径和名称为:

./datasets/data/nyu_depth_v2_labeled.mat

三. 准备hha数据

  1. 下载程序并解压(该程序用于生成hha数据)
    程序地址:https://github.com/charlesCXK/Depth2HHA
    在这里插入图片描述

  2. 上述程序为Matlab程序,需要使用 NYU_Depth_V2数据集中的图片,百度网盘有直接处理好的图片,可以直接下载(链接,提取码: f1n8)。
    在这里插入图片描述
    图像提取作者原文链接:https://blog.csdn.net/sinat_26871259/article/details/82351276

  3. nyu_depths 文件夹中的图片复制到上文下载的 depth 文件夹中
    在这里插入图片描述

  4. 对源程序做如下修改:

  • 修改 main.m
clc;
addpath('./utils/nyu-hooks');
depth_image_root = './depth'       % dir where depth images are in.
rawdepth_image_root = './rawdepth'       % dir where raw depth images are in.
hha_image_root = './hha'

C = getCameraParam('color');

for i=0:1448
    %i
    matrix = C;    %camera_matrix(1+(i-1)*3:i*3,:);        % matrix of this image, 3*3
    D = imread(fullfile(depth_image_root, '/', [mat2str(i),'.png']));
    % here, RD is the same as D, because there is some problem about NYU Depth V2 raw-depth-images
    RD = imread(fullfile(depth_image_root, '/', [mat2str(i),'.png']));
    hha = saveHHA([mat2str(i)], matrix, hha_image_root, D, RD);
    % hha = saveHHA(['complete_img_', mat2str(5000+i)], matrix, hha_image_root, D, D);
end 

上述程序较github源程序改动较多,主要有:
1)循环次数;
2)RD读取路径,没有选择 rawdepth_image_root据说这个数据集中raw-depth数据有误
3)文件读取和保存去掉了 img_ 的开头,命名也从0开始,而没有从5001开始

  • 修改saveHHA.m
D = double(D)/1000;        % The unit of the element inside D is 'centimeter'

注意: 如果除以1000,得到的图片都会偏红,为此我经过多次尝试,将此数改为25,得到的hha图片下过如下:
在这里插入图片描述

  1. 运行上述程序,生成hha数据
    在这里插入图片描述
  2. 将上述图片保存到Ubuntu环境下,保存路径为:

./datasets/data/hha/

在这里插入图片描述

四. 载入数据集

  1. 修改 nyudv2.py 文件(该文件作用为读取并保存数据集到类的实例)
f = h5py.File(data_path + data_file, 'r')

添加 ‘r’ 可以避免输出警告信息

hha = np.transpose(cv2.imread("datasets/data/hha/" + str(idx) + ".png", cv2.COLOR_BGR2RGB), [1, 0, 2])

str(idx+1) 改为 str(idx),因为hha图片从0开始编号

nyudv2.py文件的全文注释如下所示:

from torch.utils.data import Dataset
import glob
import numpy as np
import cv2
import h5py


class Dataset(Dataset):
    def __init__(self, flip_prob=None, crop_type=None, crop_size=0):
        
        '''参数初始化,是否需要随机翻转,是否需要随机裁减'''
        self.flip_prob = flip_prob
        self.crop_type = crop_type
        self.crop_size = crop_size
        
        '''需要读取的文件路径及文件名称'''
        data_path = 'datasets/data/'
        data_file = 'nyu_depth_v2_labeled.mat'

        # 读取 mat 文件
        print("Reading .mat file...")
        f = h5py.File(data_path + data_file,'r')

        # as it turns out, trying to pickle this is a shit idea :D
        rgb_images_fr = np.transpose(f['images'], [0, 2, 3, 1]).astype(np.float32)
        label_images_fr = np.array(f['labels'])
        
        '''关闭文件'''
        f.close()
        
        '''通过在prediction.py文件中将变量维度输出,得知该变量维度为(1449, 640, 480, 3)'''
        self.rgb_images = rgb_images_fr
        '''通过在prediction.py文件中将变量维度输出,得知该变量维度为(1449, 640, 480)'''
        self.label_images = label_images_fr
        
    def __len__(self):
        '''如果是多维矩阵,一般返回第一个维度'''
        return len(self.rgb_images)

    def __getitem__(self, idx):
    
        '''rgb存储rgb图片,维度为(640, 480, 3)'''
        rgb = self.rgb_images[idx].astype(np.float32)
        '''hha存储hha图片,维度为(640, 480, 3)'''
        hha = np.transpose(cv2.imread("datasets/data/hha/" + str(idx) + ".png", cv2.COLOR_BGR2RGB), [1, 0, 2])
        '''rgb_hha将rgb图片和hha图片进行了合并 维度 (640, 480, 6)'''
        rgb_hha = np.concatenate([rgb, hha], axis=2).astype(np.float32)
        '''维度 (640, 480)'''
        label = self.label_images[idx].astype(np.float32)
        label[label >= 14] = 0
        '''构造和rgb维度相同的零矩阵,因为有0:2,所以维度为 (640, 480, 2)'''
        xy = np.zeros_like(rgb)[:,:,0:2].astype(np.float32)

        # random crop 随机裁减 不但提高了模型精度,也增强了模型稳定性,
        if self.crop_type is not None and self.crop_size > 0:
            max_margin = rgb_hha.shape[0] - self.crop_size
            if max_margin == 0:  # crop is original size, so nothing to crop
                self.crop_type = None
            elif self.crop_type == 'Center':
                rgb_hha = rgb[max_margin // 2:-max_margin // 2, max_margin // 2:-max_margin // 2, :]
                label = label[max_margin // 2:-max_margin // 2, max_margin // 2:-max_margin // 2]
                xy = xy[max_margin // 2:-max_margin // 2, max_margin // 2:-max_margin // 2, :]
            elif self.crop_type == 'Random':
                x_ = np.random.randint(0, max_margin)
                y_ = np.random.randint(0, max_margin)
                rgb_hha = rgb_hha[y_:y_ + self.crop_size, x_:x_ + self.crop_size, :]
                label = label[y_:y_ + self.crop_size, x_:x_ + self.crop_size]
                xy = xy[y_:y_ + self.crop_size, x_:x_ + self.crop_size, :]
            else:
                print('Bad crop')  # TODO make this more like, you know, good software
                exit(0)

        # random flip 随机翻转(一行的左右进行翻转),提高模型范化能力
        if self.flip_prob is not None:
            if np.random.random() > self.flip_prob:
                rgb_hha = np.fliplr(rgb_hha).copy()
                label = np.fliplr(label).copy()
                xy = np.fliplr(xy).copy()

        '''return 实际上也就确定了数据集的格式 分别是rgb_hha, label, xy数据'''
        return rgb_hha, label, xy
  1. 修改 run.py 文件(模型的训练和预测,最重要的程序)
  • 修改epoch,循环次数
  • 修改batch,每次批处理图片的数量
  • 修改初始参数,可以使用原来训练好的参数
  • 修改默认使用的GPU
    。。。。。。

run.py全文注释如下所示:

import cv2
import os
import sys
import time
import numpy as np
import datetime
import logging
import torch
import torch.optim
import torch.nn as nn
from torch.utils.data import DataLoader
from datasets import nyudv2
from models import Model
import config
from tqdm import tqdm
import argparse

torch.multiprocessing.set_sharing_strategy('file_system')
torch.backends.cudnn.benchmark = True

'''参数初始化'''
def parse_args():
    '''PARAMETERS'''
    parser = argparse.ArgumentParser('3dgnn')
    
    '''指定循环次数'''
    parser.add_argument('--num_epochs', default=100,type=int,
                        help='Number of epoch')
    
    '''指定批量大小'''
    parser.add_argument('--batchsize', type=int, default=6,
                        help='batch size in training')
                        
    '''修改default,可以使用训练好的参数'''                    
    parser.add_argument('--pretrain', type=str, default='experiment/2021-01-04-10/save/3dgnn_finish.pth',
                        help='Direction for pretrained weight')
    
    '''指定GPU,一般从0开始编号'''
    parser.add_argument('--gpu', type=str, default='0,1',
                        help='specify gpu device')

    return parser.parse_args()

def main(args):

    '''指定GPU'''
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    
    '''指定日志文件的存放位置'''
    logger = logging.getLogger('3dgnn')
    log_path = './experiment/'+ str(datetime.datetime.now().strftime('%Y-%m-%d-%H')).replace(' ', '/') + '/'
    print('log path is:',log_path)
    if not os.path.exists(log_path):
        os.makedirs(log_path)
        os.makedirs(log_path + 'save/')
    hdlr = logging.FileHandler(log_path + 'log.txt')
    formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s')
    hdlr.setFormatter(formatter)
    logger.addHandler(hdlr)
    logger.setLevel(logging.INFO)
    logger.info("Loading data...")
    print("Loading data...")
    
    '''创建字典,将标签和数字'''
    label_to_idx = {'<UNK>': 0, 'beam': 1, 'board': 2, 'bookcase': 3, 'ceiling': 4, 'chair': 5, 'clutter': 6,
                    'column': 7,
                    'door': 8, 'floor': 9, 'sofa': 10, 'table': 11, 'wall': 12, 'window': 13}

    idx_to_label = {0: '<UNK>', 1: 'beam', 2: 'board', 3: 'bookcase', 4: 'ceiling', 5: 'chair', 6: 'clutter',
                    7: 'column',
                    8: 'door', 9: 'floor', 10: 'sofa', 11: 'table', 12: 'wall', 13: 'window'}


    '''config在该文件夹下有相应的python源代码,Dataset指定了相应的训练集'''
    dataset_tr = nyudv2.Dataset(flip_prob=config.flip_prob,crop_type='Random',crop_size=config.crop_size)
    dataloader_tr = DataLoader(dataset_tr, batch_size=args.batchsize, shuffle=True,
                               num_workers=config.workers_tr, drop_last=False, pin_memory=True)

    '''dataset_va是预测集'''
    dataset_va = nyudv2.Dataset(flip_prob=0.0,crop_type='Center',crop_size=config.crop_size)
    dataloader_va = DataLoader(dataset_va, batch_size=args.batchsize, shuffle=False,
                               num_workers=config.workers_va, drop_last=False, pin_memory=True)
    cv2.setNumThreads(config.workers_tr)

    '''日志文件需要添加的信息'''
    logger.info("Preparing model...")
    print("Preparing model...")
    
    '''模型初始化'''
    model = Model(config.nclasses, config.mlp_num_layers,config.use_gpu)
    loss = nn.NLLLoss(reduce=not config.use_bootstrap_loss, weight=torch.FloatTensor(config.class_weights))
    
    '''dim表示维度,dim=0,表示行,dim=1,表示列'''
    softmax = nn.Softmax(dim=1)
    log_softmax = nn.LogSoftmax(dim=1)

    '''使用cuda加速'''
    if config.use_gpu:
        model = model.cuda()
        loss = loss.cuda()
        softmax = softmax.cuda()
        log_softmax = log_softmax.cuda()
    
    '''优化器,选择Adam'''
    optimizer = torch.optim.Adam([{'params': model.decoder.parameters()},
                                  {'params': model.gnn.parameters(), 'lr': config.gnn_initial_lr}],
                                 lr=config.base_initial_lr, betas=config.betas, eps=config.eps, weight_decay=config.weight_decay)
    
    '''学习率调整策略,exp指指数衰减调整,plateau指自适应调整'''
    '''https://blog.csdn.net/shanglianlm/article/details/85143614'''
    '''lambda为匿名函数,epoch为目前循环的次数 https://www.cnblogs.com/huangbiquan/p/8030298.html'''
    if config.lr_schedule_type == 'exp':
        lambda1 = lambda epoch: pow((1 - ((epoch - 1) / args.num_epochs)), config.lr_decay)
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1)
    elif config.lr_schedule_type == 'plateau':
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=config.lr_decay, patience=config.lr_patience)
    else:
        print('bad scheduler')
        exit(1)

    '''记录训练参数数量'''
    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    params = sum([np.prod(p.size()) for p in model_parameters])
    logger.info("Number of trainable parameters: %d", params)


    '''得到目前的学习率'''
    def get_current_learning_rates():
        learning_rates = []
        for param_group in optimizer.param_groups:
            learning_rates.append(param_group['lr'])
        return learning_rates


    '''评估/预测,对输入数据进行评估/预测'''
    def eval_set(dataloader):
        model.eval()
        
        '''torch.no_grad()函数使得程序不计算梯度,只进行前向传播,用在预测中正合适'''
        with torch.no_grad():
            loss_sum = 0.0
            
            '''混淆矩阵'''
            confusion_matrix = torch.cuda.FloatTensor(np.zeros(14 ** 2))

            start_time = time.time()

            '''tqdm是进度条模块'''
            for batch_idx, rgbd_label_xy in tqdm(enumerate(dataloader),total = len(dataloader),smoothing=0.9):
                
                #torch.Tensor 类型 [6, 640, 480, 6] 对应数据集中的 rgb_hha 矩阵
                #第一个6是batch大小,640和480是宽和高,第二个6是因为其为rgb与hha的拼接
                x = rgbd_label_xy[0]
                
                #torch.Tensor 类型 [6, 640, 480, 2] 对应数据集中的 xy 矩阵
                #第一个6是batch大小,640和480是宽和高,2是因为 xy 数据集只有两个维度,且为全零矩阵
                xy = rgbd_label_xy[2]
                
                #torch.Tensor 类型 [6, 640, 480] 第一个6是batch大小,640和480是宽和高
                target = rgbd_label_xy[1].long()
                
                x = x.float()
                xy = xy.float()
                
                '''permute函数用于转换Tensor的维度,contiguous()使得内存是连续的'''
                input = x.permute(0, 3, 1, 2).contiguous()
                xy = xy.permute(0, 3, 1, 2).contiguous()
                if config.use_gpu:
                    input = input.cuda()
                    xy = xy.cuda()
                    target = target.cuda()
                
                '''经过网络,计算输出, 维度为 ([6, 14, 640, 480])'''
                output = model(input, gnn_iterations=config.gnn_iterations, k=config.gnn_k, xy=xy, use_gnn=config.use_gnn)
                
                '''config.use_bootstrap_loss为False'''
                if config.use_bootstrap_loss:
                    loss_per_pixel = loss.forward(log_softmax(output.float()), target)
                    topk, indices = torch.topk(loss_per_pixel.view(output.size()[0], -1),
                                               int((config.crop_size ** 2) * config.bootstrap_rate))
                    loss_ = torch.mean(topk)
                else:
                    '''log_softmax在softmax的结果上再做多一次log运算'''
                    loss_ = loss.forward(log_softmax(output.float()), target)
                loss_sum += loss_

                '''pred维度为 ([6, 640, 480, 14]), 连续内存'''
                pred = output.permute(0, 2, 3, 1).contiguous()
                '''此时pred维度为 ([1843200, 14]), 其中1843200=6*640*480  config.nclasses=14'''
                pred = pred.view(-1, config.nclasses)
                '''每一行进行softmax运算,相当于对每一个像素的分类进行softmax运算'''
                pred = softmax(pred)
                '''pred_max_val, pred_arg_max都是1843200维,分别存储每个像素最大的分类值及分类'''
                pred_max_val, pred_arg_max = pred.max(1)

                '''pairs为1843200维'''
                pairs = target.view(-1) * 14 + pred_arg_max.view(-1)
                
                '''计算混淆矩阵'''
                for i in range(14 ** 2):
                    cumu = pairs.eq(i).float().sum()
                    confusion_matrix[i] += cumu.item()

            sys.stdout.write(" - Eval time: {:.2f}s \n".format(time.time() - start_time))
            loss_sum /= len(dataloader)

            confusion_matrix = confusion_matrix.cpu().numpy().reshape((14, 14))
            class_iou = np.zeros(14)
            confusion_matrix[0, :] = np.zeros(14)
            confusion_matrix[:, 0] = np.zeros(14)
            
            '''计算交并比'''
            for i in range(1, 14):
                class_iou[i] = confusion_matrix[i, i] / (
                        np.sum(confusion_matrix[i, :]) + np.sum(confusion_matrix[:, i]) - confusion_matrix[i, i])

        return loss_sum.item(), class_iou, confusion_matrix

    '''Training parameter 训练参数'''
    model_to_load = args.pretrain
    logger.info("num_epochs: %d", args.num_epochs)
    print("Number of epochs: %d"%args.num_epochs)
    interval_to_show = 200

    train_losses = []
    eval_losses = []

    '''判断使用原来训练过的模型参数,还是从零开始训练'''
    if model_to_load:
        logger.info("Loading old model...")
        print("Loading old model...")
        model.load_state_dict(torch.load(model_to_load))
    else:
        logger.info("Starting training from scratch...")
        print("Starting training from scratch...")

    
    '''Training'''
    '''循环训练,range范围是最后一个参数-1,所以想要实现指定次数,需要+1'''
    for epoch in range(1, args.num_epochs + 1):
        batch_loss_avg = 0
        
        '''学习率更新参数'''
        if config.lr_schedule_type == 'exp':
            scheduler.step(epoch)
        
        '''训练过程'''
        for batch_idx, rgbd_label_xy in tqdm(enumerate(dataloader_tr),total = len(dataloader_tr),smoothing=0.9):
            x = rgbd_label_xy[0]
            target = rgbd_label_xy[1].long()
            xy = rgbd_label_xy[2]
            x = x.float()
            xy = xy.float()

            input = x.permute(0, 3, 1, 2).contiguous()
            input = input.type(torch.FloatTensor)

            if config.use_gpu:
                input = input.cuda()
                xy = xy.cuda()
                target = target.cuda()

            xy = xy.permute(0, 3, 1, 2).contiguous()

            optimizer.zero_grad()
            model.train()

            output = model(input, gnn_iterations=config.gnn_iterations, k=config.gnn_k, xy=xy, use_gnn=config.use_gnn)
                
            '''config.use_bootstrap_loss=False'''
            if config.use_bootstrap_loss:
                loss_per_pixel = loss.forward(log_softmax(output.float()), target)
                topk, indices = torch.topk(loss_per_pixel.view(output.size()[0], -1),
                                           int((config.crop_size ** 2) * config.bootstrap_rate))
                loss_ = torch.mean(topk)
            else:
                loss_ = loss.forward(log_softmax(output.float()), target)

            loss_.backward()
            optimizer.step()

            batch_loss_avg += loss_.item()

            if batch_idx % interval_to_show == 0 and batch_idx > 0:
                batch_loss_avg /= interval_to_show
                train_losses.append(batch_loss_avg)
                logger.info("E%dB%d Batch loss average: %s", epoch, batch_idx, batch_loss_avg)
                print('\rEpoch:{}, Batch:{}, loss average:{}'.format(epoch, batch_idx, batch_loss_avg))
                batch_loss_avg = 0


        '''训练结束,后续保存参数,并进行测试'''
        batch_idx = len(dataloader_tr)
        logger.info("E%dB%d Saving model...", epoch, batch_idx)

        '''保存模型参数'''
        torch.save(model.state_dict(),log_path +'/save/'+'checkpoint_'+str(epoch)+'.pth')

        '''Evaluation'''
        '''每一次训练完以后,用测试集进行测试,看看分类效果'''
        eval_loss, class_iou, confusion_matrix = eval_set(dataloader_va)
        eval_losses.append(eval_loss)

        '''另一种学习率更新的情况'''
        if config.lr_schedule_type == 'plateau':
            scheduler.step(eval_loss)
        print('Learning ...')
        logger.info("E%dB%d Def learning rate: %s", epoch, batch_idx, get_current_learning_rates()[0])
        print('Epoch{} Def learning rate: {}'.format(epoch, get_current_learning_rates()[0]))
        logger.info("E%dB%d GNN learning rate: %s", epoch, batch_idx, get_current_learning_rates()[1])
        print('Epoch{} GNN learning rate: {}'.format(epoch, get_current_learning_rates()[1]))
        logger.info("E%dB%d Eval loss: %s", epoch, batch_idx, eval_loss)
        print('Epoch{} Eval loss: {}'.format(epoch, eval_loss))
        logger.info("E%dB%d Class IoU:", epoch, batch_idx)
        print('Epoch{} Class IoU:'.format(epoch))
        for cl in range(14):
            logger.info("%+10s: %-10s" % (idx_to_label[cl], class_iou[cl]))
            print('{}:{}'.format(idx_to_label[cl], class_iou[cl]))
        logger.info("Mean IoU: %s", np.mean(class_iou[1:]))
        print("Mean IoU: %.2f"%np.mean(class_iou[1:]))
        logger.info("E%dB%d Confusion matrix:", epoch, batch_idx)
        logger.info(confusion_matrix)
    
    '''所有循环都结束了,此时在日志文件进行记录'''
    logger.info("Finished training!")
    logger.info("Saving model...")
    print('Saving final model...')
    
    '''保存模型从参数'''
    torch.save(model.state_dict(), log_path + '/save/3dgnn_finish.pth')
    
    '''预测集(也是训练集),重新放入网络进行前向传播,评估损失和交并比'''
    eval_loss, class_iou, confusion_matrix = eval_set(dataloader_va)
    
    '''日志文件输出'''
    logger.info("Eval loss: %s", eval_loss)
    logger.info("Class IoU:")
    
    '''14个分类,每一类的IoU(交并比)'''
    for cl in range(14):
        logger.info("%+10s: %-10s" % (idx_to_label[cl], class_iou[cl]))
    
    '''所有类的平均交并比(IoU)'''
    logger.info("Mean IoU: %s", np.mean(class_iou[1:]))

if __name__ == '__main__':
    args = parse_args()
    main(args)
  1. 在Terminal 运行run.py,训练过程如下所示:
    在这里插入图片描述
    下面是一个完整训练周期的输出,对输出的内容的含义进行详细介绍:
    在这里插入图片描述
  • 第一行:Epoch 76 第76个训练周期;Batch 100,训练了100批次(每个批次6张图片);loss average 0.0825 平均损失; 100/242 一共242个批次,现在训练了100个批次;01:03 训练这100个批次花费1分3秒;01:29 剩余训练估计还需耗时1分29秒;1.59it/s 每秒训练1.59个批次(每个批次6张图片)
  • 第二行 同理
  • 第三行 100% 表示训练完毕,一共用时2分32秒,剩余还需要0秒
  • 第四行 100% 表示预测完毕,对242个批次全部完成了预测,耗时1分11秒,剩余还需要耗时0秒,每秒预测3.36个批次(每个批次6张图片)
  • 第五行 Eval time:72.5s 预测(评价/前向传播)消耗时间 72.5秒
  • 第六行 Learning …
  • 后续就是学习率、各类的交并比及平均交并比等
  1. 训练结束后,程序会保存模型参数,如下所示:
    在这里插入图片描述
    3dgnn_finish.pth为最后一次训练后保存的模型,如果想再次训练,可以利用该参数继续进行训练。

到此为止,完成了模型的训练,接下来利用训练好的模型进行预测。

五. 预测

将图片输入网络,输出分类结果,进行预测。因为网络训练时,用到了前向传播函数,因此在run.py文件的基础上进行修改,可以得到预测的源程序。前向传播函数如下所示,其中 dataloader_va 为测试数据集。

eval_loss, class_iou, confusion_matrix = eval_set(dataloader_va)

主要修改内容如下所示:

  1. 将训练过程删除
for epoch in range(1, args.num_epochs + 1):
        batch_loss_avg = 0
        
        '''学习率更新参数'''
        if config.lr_schedule_type == 'exp':
            scheduler.step(epoch)
        
        '''训练过程'''
        for batch_idx, rgbd_label_xy in tqdm(enumerate(dataloader_tr),total = len(dataloader_tr),smoothing=0.9):
            x = rgbd_label_xy[0]
            target = rgbd_label_xy[1].long()
            xy = rgbd_label_xy[2]
            x = x.float()
            xy = xy.float()
            ......
  1. 将分类数据保存为数组,并且改变形状,使得每行代表一张图片
'''将分类数据保存为数组,并且改变形状,使得每行代表一张图片'''
result = pred_arg_max.cpu().numpy().reshape(name_for_txt,307200)
Width = 640
Height = 480
  1. 将数组转换为图片,并且不同的分类用不同的颜色绘制
'''创建空矩阵,用于存放一张图片每个像素的分类数据'''
                Empty_array = np.zeros((Width,Height,3), dtype = np.uint8)
                
                for ii in range(name_for_txt):
                    row_Frame = result[ii] #将一张图片的数据单独保存
                    for w in range(Width):
                        for h in range(Height):
                            '''判断属于哪一类,不一样的类赋予不同的颜色'''
                            if row_Frame[w*Height+h] == 0:
                                #未知类,RGB为0,黑色
                                Empty_array[w,h,0] = 0
                                Empty_array[w,h,1] = 0
                                Empty_array[w,h,2] = 0
                            elif row_Frame[w*Height+h] == 1:
                                #beam类,RGB为 石板灰
                                Empty_array[w,h,0] = 112
                                Empty_array[w,h,1] = 128
                                Empty_array[w,h,2] = 105
                            elif row_Frame[w*Height+h] == 2:
                                #board类,RGB为 马棕色
                                Empty_array[w,h,0] = 139
                                Empty_array[w,h,1] = 69
                                Empty_array[w,h,2] = 19
                            elif row_Frame[w*Height+h] == 3:
                                #bookcase类,RGB为 乌贼墨棕色
                                Empty_array[w,h,0] = 94
                                Empty_array[w,h,1] = 38
                                Empty_array[w,h,2] = 18
                            elif row_Frame[w*Height+h] == 4:
                                #ceiling类,RGB为 
                                Empty_array[w,h,0] = 220
                                Empty_array[w,h,1] = 220
                                Empty_array[w,h,2] = 220
                            elif row_Frame[w*Height+h] == 5:
                                #chair类,RGB为 玫瑰红 
                                Empty_array[w,h,0] = 188
                                Empty_array[w,h,1] = 143
                                Empty_array[w,h,2] = 143
                            elif row_Frame[w*Height+h] == 6:
                                #clutter类,RGB为 镉红
                                Empty_array[w,h,0] = 227
                                Empty_array[w,h,1] = 23
                                Empty_array[w,h,2] = 13
                            elif row_Frame[w*Height+h] == 7:
                                #column类,RGB为 紫色
                                Empty_array[w,h,0] = 160
                                Empty_array[w,h,1] = 32
                                Empty_array[w,h,2] = 240
                            elif row_Frame[w*Height+h] == 8:
                                #door类,RGB为 黄绿色 
                                Empty_array[w,h,0] = 127
                                Empty_array[w,h,1] = 255
                                Empty_array[w,h,2] = 0
                            elif row_Frame[w*Height+h] == 9:
                                #floor类,RGB为 白杏仁
                                Empty_array[w,h,0] = 255
                                Empty_array[w,h,1] = 235
                                Empty_array[w,h,2] = 205
                            elif row_Frame[w*Height+h] == 10:
                                #sofa类,RGB为 棕色
                                Empty_array[w,h,0] = 128
                                Empty_array[w,h,1] = 42
                                Empty_array[w,h,2] = 42
                            elif row_Frame[w*Height+h] == 11:
                                #table类,RGB为 淡黄
                                Empty_array[w,h,0] = 245
                                Empty_array[w,h,1] = 222
                                Empty_array[w,h,2] = 179
                            elif row_Frame[w*Height+h] == 12:
                                #wall类,RGB为 天蓝色 
                                Empty_array[w,h,0] = 240
                                Empty_array[w,h,1] = 255
                                Empty_array[w,h,2] = 255
                            else: 
                                #windows类,RGB为 黄色 
                                Empty_array[w,h,0] = 222
                                Empty_array[w,h,1] = 255
                                Empty_array[w,h,2] = 0
                    #print(Empty_array.shape)
                    # 将数组转化为图片
                    img = Image.fromarray(Empty_array).convert('RGB').rotate(90)  
                    # 将数组保存为图片
                    img.save(eval_path+str(batch_idx*6+ii+1)+'.png') 

全部程序如下所示,并且进行了注释

import cv2
import os
import sys
import time
import numpy as np
import datetime
import logging
import torch
import torch.optim
import torch.nn as nn
from torch.utils.data import DataLoader
from datasets import nyudv2
from models import Model
import config
from tqdm import tqdm
import argparse
from PIL import Image

torch.multiprocessing.set_sharing_strategy('file_system')
torch.backends.cudnn.benchmark = True

'''参数初始化'''
def parse_args():
    '''PARAMETERS'''
    parser = argparse.ArgumentParser('3dgnn')
    
    '''指定循环次数'''
    parser.add_argument('--num_epochs', default=100,type=int,
                        help='Number of epoch')
    
    '''指定批量大小'''
    parser.add_argument('--batchsize', type=int, default=6,
                        help='batch size in training')
                        
    '''修改default,可以使用训练好的参数'''                    
    parser.add_argument('--pretrain', type=str, default='experiment/2021-01-06-14/save/3dgnn_finish.pth',
                        help='Direction for pretrained weight')
    
    '''指定GPU,一般从0开始编号'''
    parser.add_argument('--gpu', type=str, default='0',
                        help='specify gpu device')

    return parser.parse_args()

def main(args):

    '''指定GPU'''
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    
    '''指定日志文件的存放位置'''
    logger = logging.getLogger('3dgnn')
    log_path = './eval/'+ str(datetime.datetime.now().strftime('%Y-%m-%d-%H')).replace(' ', '/') + '/'
    print('log path is:',log_path)
    if not os.path.exists(log_path):
        os.makedirs(log_path)
    hdlr = logging.FileHandler(log_path + 'log.txt')
    formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s')
    hdlr.setFormatter(formatter)
    logger.addHandler(hdlr)
    logger.setLevel(logging.INFO)
    logger.info("Loading data...")
    print("Loading data...")
    
    '''创建文件夹,保存前向传播后的预测结果'''
    eval_path = './eval/'+ str(datetime.datetime.now().strftime('%Y-%m-%d-%H')).replace(' ', '/') + '/'
    if not os.path.exists(eval_path):
        os.makedirs(eval_path)
    
    '''创建字典,将标签和数字'''
    label_to_idx = {'<UNK>': 0, 'beam': 1, 'board': 2, 'bookcase': 3, 'ceiling': 4, 'chair': 5, 'clutter': 6,
                    'column': 7,
                    'door': 8, 'floor': 9, 'sofa': 10, 'table': 11, 'wall': 12, 'window': 13}

    idx_to_label = {0: '<UNK>', 1: 'beam', 2: 'board', 3: 'bookcase', 4: 'ceiling', 5: 'chair', 6: 'clutter',
                    7: 'column',
                    8: 'door', 9: 'floor', 10: 'sofa', 11: 'table', 12: 'wall', 13: 'window'}


    '''config在该文件夹下有相应的python源代码,Dataset指定了相应的训练集'''
    dataset_tr = nyudv2.Dataset(flip_prob=config.flip_prob,crop_type='Random',crop_size=config.crop_size)
    dataloader_tr = DataLoader(dataset_tr, batch_size=args.batchsize, shuffle=True,
                               num_workers=config.workers_tr, drop_last=False, pin_memory=True)
    
    '''dataset_va是预测集'''
    dataset_va = nyudv2.Dataset(flip_prob=0.0,crop_type='Center',crop_size=config.crop_size)
    dataloader_va = DataLoader(dataset_va, batch_size=args.batchsize, shuffle=False,
                               num_workers=config.workers_va, drop_last=False, pin_memory=True)
    cv2.setNumThreads(config.workers_tr)
    
    '''日志文件需要添加的信息'''
    logger.info("Preparing model...")
    print("Preparing model...")
    
    '''模型初始化'''
    model = Model(config.nclasses, config.mlp_num_layers,config.use_gpu)
    loss = nn.NLLLoss(reduce=not config.use_bootstrap_loss, weight=torch.FloatTensor(config.class_weights))
    
    '''dim表示维度,dim=0,表示行,dim=1,表示列'''
    softmax = nn.Softmax(dim=1)
    log_softmax = nn.LogSoftmax(dim=1)

    '''使用cuda加速'''
    if config.use_gpu:
        model = model.cuda()
        loss = loss.cuda()
        softmax = softmax.cuda()
        log_softmax = log_softmax.cuda()
    
    '''优化器,选择Adam'''
    optimizer = torch.optim.Adam([{'params': model.decoder.parameters()},
                                  {'params': model.gnn.parameters(), 'lr': config.gnn_initial_lr}],
                                 lr=config.base_initial_lr, betas=config.betas, eps=config.eps, weight_decay=config.weight_decay)
    
    '''学习率调整策略,exp指指数衰减调整,plateau指自适应调整'''
    '''https://blog.csdn.net/shanglianlm/article/details/85143614'''
    '''lambda为匿名函数,epoch为目前循环的次数 https://www.cnblogs.com/huangbiquan/p/8030298.html'''
    if config.lr_schedule_type == 'exp':
        lambda1 = lambda epoch: pow((1 - ((epoch - 1) / args.num_epochs)), config.lr_decay)
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1)
    elif config.lr_schedule_type == 'plateau':
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=config.lr_decay, patience=config.lr_patience)
    else:
        print('bad scheduler')
        exit(1)

    '''计算训练参数数量'''
    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    params = sum([np.prod(p.size()) for p in model_parameters])

    '''得到目前的学习率'''
    def get_current_learning_rates():
        learning_rates = []
        for param_group in optimizer.param_groups:
            learning_rates.append(param_group['lr'])
        return learning_rates
    
    '''评估/预测,对输入数据进行评估/预测'''
    def eval_set(dataloader):
        model.eval()
        
        '''torch.no_grad()函数使得程序不计算梯度,只进行前向传播,用在预测中正合适'''
        with torch.no_grad():
            loss_sum = 0.0
            
            '''混淆矩阵'''
            confusion_matrix = torch.cuda.FloatTensor(np.zeros(14 ** 2))

            start_time = time.time()

            '''tqdm是进度条模块  batch_idx指目前是第几个batch'''
            for batch_idx, rgbd_label_xy in tqdm(enumerate(dataloader),total = len(dataloader),smoothing=0.9):
                
                #torch.Tensor 类型 [6, 640, 480, 6] 对应数据集中的 rgb_hha 矩阵
                #第一个6是batch大小,640和480是宽和高,第二个6是因为其为rgb与hha的拼接
                x = rgbd_label_xy[0]    
                
                #torch.Tensor 类型 [6, 640, 480, 2] 对应数据集中的 xy 矩阵
                #第一个6是batch大小,640和480是宽和高,2是因为 xy 数据集只有两个维度,且为全零矩阵 
                xy = rgbd_label_xy[2]   
                
                #torch.Tensor 类型 [6, 640, 480] 第一个6是batch大小,640和480是宽和高
                target = rgbd_label_xy[1].long()  
                
                x = x.float()
                xy = xy.float()
                
                '''permute函数用于转换Tensor的维度,contiguous()使得内存是连续的'''
                input = x.permute(0, 3, 1, 2).contiguous()
                xy = xy.permute(0, 3, 1, 2).contiguous()
                if config.use_gpu:
                    input = input.cuda()
                    xy = xy.cuda()
                    target = target.cuda()
                
                '''经过网络,计算输出, 维度为 ([6, 14, 640, 480])'''
                output = model(input, gnn_iterations=config.gnn_iterations, k=config.gnn_k, xy=xy, use_gnn=config.use_gnn)
                
                '''config.use_bootstrap_loss为False'''
                if config.use_bootstrap_loss:
                    loss_per_pixel = loss.forward(log_softmax(output.float()), target)
                    topk, indices = torch.topk(loss_per_pixel.view(output.size()[0], -1),
                                               int((config.crop_size ** 2) * config.bootstrap_rate))
                    loss_ = torch.mean(topk)
                else:
                    '''log_softmax在softmax的结果上再做多一次log运算'''
                    loss_ = loss.forward(log_softmax(output.float()), target)
                loss_sum += loss_

                
                '''pred维度为 ([6, 640, 480, 14]), 连续内存'''
                pred = output.permute(0, 2, 3, 1).contiguous()
                '''源程序没有,专门存放这个batch的大小,后面循环用到'''
                name_for_txt=len(pred)
                '''此时pred维度为 ([1843200, 14]), 其中1843200=6*640*480  config.nclasses=14'''
                pred = pred.view(-1, config.nclasses)
                '''每一行进行softmax运算,相当于对每一个像素的分类进行softmax运算'''
                pred = softmax(pred)
                '''pred_max_val, pred_arg_max都是1843200维,分别存储每个像素最大的分类概率及分类'''
                pred_max_val, pred_arg_max = pred.max(1)
                
                '''将分类数据保存为数组,并且改变形状,使得每行代表一张图片'''
                result = pred_arg_max.cpu().numpy().reshape(name_for_txt,307200)
                Width = 640
                Height = 480
                '''创建空矩阵,用于存放一张图片每个像素的分类数据'''
                Empty_array = np.zeros((Width,Height,3), dtype = np.uint8)
                
                for ii in range(name_for_txt):
                    row_Frame = result[ii] #将一张图片的数据单独保存
                    for w in range(Width):
                        for h in range(Height):
                            '''判断属于哪一类,不一样的类赋予不同的颜色'''
                            if row_Frame[w*Height+h] == 0:
                                #未知类,RGB为0,黑色
                                Empty_array[w,h,0] = 0
                                Empty_array[w,h,1] = 0
                                Empty_array[w,h,2] = 0
                            elif row_Frame[w*Height+h] == 1:
                                #beam类,RGB为 石板灰
                                Empty_array[w,h,0] = 112
                                Empty_array[w,h,1] = 128
                                Empty_array[w,h,2] = 105
                            elif row_Frame[w*Height+h] == 2:
                                #board类,RGB为 马棕色
                                Empty_array[w,h,0] = 139
                                Empty_array[w,h,1] = 69
                                Empty_array[w,h,2] = 19
                            elif row_Frame[w*Height+h] == 3:
                                #bookcase类,RGB为 乌贼墨棕色
                                Empty_array[w,h,0] = 94
                                Empty_array[w,h,1] = 38
                                Empty_array[w,h,2] = 18
                            elif row_Frame[w*Height+h] == 4:
                                #ceiling类,RGB为 
                                Empty_array[w,h,0] = 220
                                Empty_array[w,h,1] = 220
                                Empty_array[w,h,2] = 220
                            elif row_Frame[w*Height+h] == 5:
                                #chair类,RGB为 玫瑰红 
                                Empty_array[w,h,0] = 188
                                Empty_array[w,h,1] = 143
                                Empty_array[w,h,2] = 143
                            elif row_Frame[w*Height+h] == 6:
                                #clutter类,RGB为 镉红
                                Empty_array[w,h,0] = 227
                                Empty_array[w,h,1] = 23
                                Empty_array[w,h,2] = 13
                            elif row_Frame[w*Height+h] == 7:
                                #column类,RGB为 紫色
                                Empty_array[w,h,0] = 160
                                Empty_array[w,h,1] = 32
                                Empty_array[w,h,2] = 240
                            elif row_Frame[w*Height+h] == 8:
                                #door类,RGB为 黄绿色 
                                Empty_array[w,h,0] = 127
                                Empty_array[w,h,1] = 255
                                Empty_array[w,h,2] = 0
                            elif row_Frame[w*Height+h] == 9:
                                #floor类,RGB为 白杏仁
                                Empty_array[w,h,0] = 255
                                Empty_array[w,h,1] = 235
                                Empty_array[w,h,2] = 205
                            elif row_Frame[w*Height+h] == 10:
                                #sofa类,RGB为 棕色
                                Empty_array[w,h,0] = 128
                                Empty_array[w,h,1] = 42
                                Empty_array[w,h,2] = 42
                            elif row_Frame[w*Height+h] == 11:
                                #table类,RGB为 淡黄
                                Empty_array[w,h,0] = 245
                                Empty_array[w,h,1] = 222
                                Empty_array[w,h,2] = 179
                            elif row_Frame[w*Height+h] == 12:
                                #wall类,RGB为 天蓝色 
                                Empty_array[w,h,0] = 240
                                Empty_array[w,h,1] = 255
                                Empty_array[w,h,2] = 255
                            else: 
                                #windows类,RGB为 黄色 
                                Empty_array[w,h,0] = 222
                                Empty_array[w,h,1] = 255
                                Empty_array[w,h,2] = 0
                    #print(Empty_array.shape)
                    # 将数组转化为图片
                    img = Image.fromarray(Empty_array).convert('RGB').rotate(90)  
                    # 将数组保存为图片
                    img.save(eval_path+str(batch_idx*6+ii+1)+'.png')  
                
                """
                #该循环将每张图片的数据分别保存为txt文件
                for ii in range(name_for_txt):
                    f = open(eval_path+'pred_max_val'+str(batch_idx*6+ii)+'.txt','w+')
                    row_Frame = result[ii]
                    #Empty_array.append(row_Frame)
                    for jj in range(307200):
                        strNum = str(row_Frame[jj])
                        f.write(strNum)
                        f.write(' ')
                        f.write('\n')
                    f.close
                """ 
                '''pairs为1843200维'''
                pairs = target.view(-1) * 14 + pred_arg_max.view(-1)
                
                '''计算混淆矩阵'''
                for i in range(14 ** 2):
                    cumu = pairs.eq(i).float().sum()
                    confusion_matrix[i] += cumu.item()

            '''计算预测一共消耗了多长时间'''
            sys.stdout.write(" - Eval time: {:.2f}s \n".format(time.time() - start_time))
            
            '''计算每张照片的平均损失'''
            loss_sum /= len(dataloader)

            confusion_matrix = confusion_matrix.cpu().numpy().reshape((14, 14))
            class_iou = np.zeros(14)
            
            '''将第0类,也就是背景类,对应的行和列设为零'''
            confusion_matrix[0, :] = np.zeros(14)
            confusion_matrix[:, 0] = np.zeros(14)
            
            '''计算交并比'''
            for i in range(1, 14):
                class_iou[i] = confusion_matrix[i, i] / (
                        np.sum(confusion_matrix[i, :]) + np.sum(confusion_matrix[:, i]) - confusion_matrix[i, i])

        return loss_sum.item(), class_iou, confusion_matrix

    '''Training parameter 训练参数'''
    model_to_load = args.pretrain
    logger.info("num_epochs: %d", args.num_epochs)
    print("Number of epochs: %d"%args.num_epochs)
    interval_to_show = 200

    train_losses = []
    eval_losses = []

    '''判断使用原来训练过的模型参数,还是从零开始训练'''
    if model_to_load:
        logger.info("Loading old model...")
        print("Loading old model...")
        model.load_state_dict(torch.load(model_to_load))
    else:
        logger.info("Starting training from scratch...")
        print("Starting training from scratch...")
        
    '''预测集(也是训练集),重新放入网络进行前向传播,评估损失和交并比'''
    eval_loss, class_iou, confusion_matrix = eval_set(dataloader_va)
    
    '''日志文件输出'''
    logger.info("Eval loss: %s", eval_loss)
    logger.info("Class IoU:")
    
    '''14个分类,每一类的IoU(交并比)'''
    for cl in range(14):
        logger.info("%+10s: %-10s" % (idx_to_label[cl], class_iou[cl]))
        print('{}:{}'.format(idx_to_label[cl], class_iou[cl]))
    '''所有类的平均交并比(IoU)'''
    logger.info("Mean IoU: %s", np.mean(class_iou[1:]))
    print(np.mean(class_iou[1:]))    
    
if __name__ == '__main__':
    args = parse_args()
    main(args)

程序运行结果:
在这里插入图片描述
图片输出结果:
在这里插入图片描述
图片1的分类结果如下所示:
在这里插入图片描述
参考:https://blog.csdn.net/qq_38484430/article/details/106584587

  • 3
    点赞
  • 32
    收藏
    觉得还不错? 一键收藏
  • 5
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值