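How to Apply LDGCNN's T-SNE Visualization to DGCNN & PointNet

I got T-SNE visualization of DGCNN features working by porting the relevant parts of LDGCNN into DGCNN. The same changes also work for PointNet.
First, download the LDGCNN and DGCNN code from GitHub.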

Files to copy

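  1. In the data folder, create a new folder named extracted_feature containing two text files, test_files.txt and train_files.txt. Their contents are data/extracted_feature/test_global_feature.h5 and data/extracted_feature/train_global_feature.h5, respectively;
  2. Copy the VisionProcess folder from LDGCNN into the DGCNN root directory;
  3. Copy the tsne_visualization.py file.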
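You can also script step 1. The sketch below assumes it runs from the DGCNN repository root. make_feature_lists is a hypothetical helper and is not part of either repository.

```python
import os


def make_feature_lists(root="."):
    """Create data/extracted_feature/ plus the two list files train.py reads.

    Each list file holds one HDF5 path per line. The paths are relative to
    the DGCNN repository root, matching the entries described in step 1.
    """
    feat_dir = os.path.join(root, "data", "extracted_feature")
    os.makedirs(feat_dir, exist_ok=True)
    for split in ("train", "test"):
        list_path = os.path.join(feat_dir, split + "_files.txt")
        with open(list_path, "w") as f:
            f.write("data/extracted_feature/%s_global_feature.h5\n" % split)
    return feat_dir
```

Calling make_feature_lists() with no argument from the DGCNN root creates data/extracted_feature/ with both list files. The .h5 files they name are only written later, by save_global_feature at the end of training.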

Modifications

1. Changes to evaluate.py:

Change line 23

parser.add_argument('--model_path', default='log/model.ckpt', help='model checkpoint file path [default: log/model.ckpt]')

to

parser.add_argument('--model_path', default='log/dgcnn_model.ckpt', help='model checkpoint file path [default: log/dgcnn_model.ckpt]')
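The new default matters because the modified train() in step 3 no longer saves model.ckpt. It saves to LOG_DIR plus FLAGS.model + NAME_MODEL + "_model.ckpt". The sketch below shows that naming. It assumes DGCNN's defaults of --model dgcnn and --log_dir log, together with NAME_MODEL = ''. checkpoint_path is a hypothetical helper.

```python
import os


def checkpoint_path(log_dir, model_name, name_model=""):
    """Build the path the modified train() saves to:
    os.path.join(LOG_DIR, FLAGS.model + NAME_MODEL + "_model.ckpt")."""
    return os.path.join(log_dir, model_name + str(name_model) + "_model.ckpt")


# Assumes --model dgcnn, LOG_DIR "log" and NAME_MODEL = '' (assumed defaults).
print(checkpoint_path("log", "dgcnn"))  # on Linux/macOS: log/dgcnn_model.ckpt
```

If you train with a different --model value, for example when porting to PointNet, the file name prefix changes too. evaluate.py's --model_path must then match it.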

Change the lines around line 65

        # simple model
        pred, end_points = MODEL.get_model(pointclouds_pl, is_training_pl)
        loss = MODEL.get_loss(pred, labels_pl,end_points)

to

        # simple model
        pred, layers = MODEL.get_model(pointclouds_pl, is_training_pl)
        loss = MODEL.get_loss(pred, labels_pl)

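2. Changes to tsne_visualization.py:

LDGCNN's t-SNE script visualizes features, and it contains code that converts the data from point clouds into features. Be careful around line 40: do not accidentally comment out the for loop there.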
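LDGCNN's tsne_visualization.py does the actual plotting. To show what t-SNE computes on the extracted global features, here is a minimal NumPy-only sketch of exact t-SNE, using the standard van der Maaten and Hinton formulation. It is a stand-in for illustration, not the code inside the script. The function names, defaults (perplexity, early exaggeration, momentum schedule) and the learning-rate rule are assumptions. The sketch builds N x N matrices, so it suits a subset or the ModelNet40 test set. For the full training set, a library implementation such as sklearn.manifold.TSNE is more practical.

```python
import numpy as np


def _conditional_row(d2_row, perplexity, tol=1e-5, max_iter=50):
    """Gaussian affinities p_{j|i} for one point; the precision beta is
    binary-searched so the row's entropy equals log(perplexity)."""
    target = np.log(perplexity)
    beta, lo, hi = 1.0, 0.0, np.inf
    for _ in range(max_iter):
        # Shift by the minimum for numerical stability; normalization cancels it.
        p = np.exp(-(d2_row - d2_row.min()) * beta)
        p /= p.sum()
        entropy = -np.sum(p * np.log(np.maximum(p, 1e-12)))
        if abs(entropy - target) < tol:
            break
        if entropy > target:  # distribution too flat: sharpen it
            lo = beta
            beta = beta * 2.0 if np.isinf(hi) else (beta + hi) / 2.0
        else:                 # distribution too peaked: widen it
            hi = beta
            beta = (beta + lo) / 2.0
    return p


def tsne(x, n_components=2, perplexity=10.0, n_iter=500,
         early_exaggeration=4.0, learning_rate=None, seed=0):
    """Exact t-SNE: embed the rows of x (N x D) into N x n_components."""
    x = np.asarray(x, dtype=np.float64)
    n = x.shape[0]
    sq = np.sum(x ** 2, axis=1)
    d2 = np.maximum(sq[:, None] + sq[None, :] - 2.0 * x @ x.T, 0.0)

    # Symmetrized input affinities P.
    p = np.zeros((n, n))
    for i in range(n):
        others = np.arange(n) != i
        p[i, others] = _conditional_row(d2[i, others], perplexity)
    p = np.maximum((p + p.T) / (2.0 * n), 1e-12)

    if learning_rate is None:  # assumed rule, scaled with N
        learning_rate = max(n / (4.0 * early_exaggeration), 1.0)

    rng = np.random.default_rng(seed)
    y = rng.normal(scale=1e-4, size=(n, n_components))
    velocity = np.zeros_like(y)
    for it in range(n_iter):
        exag = early_exaggeration if it < 100 else 1.0
        momentum = 0.5 if it < 250 else 0.8
        sy = np.sum(y ** 2, axis=1)
        # Student-t kernel (1 + |y_i - y_j|^2)^-1 with the diagonal removed.
        num = 1.0 / (1.0 + sy[:, None] + sy[None, :] - 2.0 * y @ y.T)
        np.fill_diagonal(num, 0.0)
        q = np.maximum(num / num.sum(), 1e-12)
        # KL gradient: 4 * sum_j (p_ij - q_ij) * num_ij * (y_i - y_j)
        w = (exag * p - q) * num
        grad = 4.0 * (np.diag(w.sum(axis=1)) - w) @ y
        velocity = momentum * velocity - learning_rate * grad
        y = y + velocity
    return y
```

To use it on real features, pass the global feature array written by FileIO.write_h5 as x, then color the 2-D points by label. This post does not show the dataset names inside that .h5 file, so check them in VisionProcess before loading.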

3. Changes to train.py:

Around line 14, add

sys.path.append(os.path.join(BASE_DIR, 'VisionProcess'))
from FileIO import FileIO

Around line 33, add

NAME_MODEL = ''

Around line 69, add

# Feature files, which are generated after training the whole network.
# The extracted feature files are utilized to train the classifier.
path = 'data/extracted_feature'
TRAIN_FILES_CLS = provider.getDataFiles( \
    os.path.join(BASE_DIR, path + '/train_files.txt'))
TEST_FILES_CLS = provider.getDataFiles(\
    os.path.join(BASE_DIR, path + '/test_files.txt'))

Give eval_one_epoch a return value:

    return total_correct / float(total_seen)
In train(), change
        with tf.device('/gpu:'+str(GPU_INDEX)):
            pointclouds_pl, end_points = MODEL.placeholder_inputs(BATCH_SIZE, NUM_POINT)
            is_training_pl = tf.placeholder(tf.bool, shape=())
            print(is_training_pl)
to
        with tf.device('/gpu:'+str(GPU_INDEX)):
            pointclouds_pl, labels_pl = MODEL.placeholder_inputs(BATCH_SIZE, NUM_POINT)
            is_training_pl = tf.placeholder(tf.bool, shape=())
            print(is_training_pl)
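Also in train(), get_model now returns layers, and get_loss no longer takes end_points (see step 4). The model calls there must therefore change the same way as in evaluate.py. Otherwise layers is undefined when save_global_feature is called, and get_loss receives one argument too many:

            pred, layers = MODEL.get_model(pointclouds_pl, is_training_pl, bn_decay=bn_decay)
            loss = MODEL.get_loss(pred, labels_pl)

Then, still in train(), change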
        for epoch in range(MAX_EPOCH):
            log_string('**** EPOCH %03d ****' % (epoch))
            sys.stdout.flush()
             
            train_one_epoch(sess, ops, train_writer)
            eval_one_epoch(sess, ops, test_writer)
            
            # Save the variables to disk.
            if epoch % 10 == 0:
                save_path = saver.save(sess, os.path.join(LOG_DIR, "model.ckpt"))
                log_string("Model saved in file: %s" % save_path)

to
        best_accuracy = 0.0
        for epoch in range(MAX_EPOCH):
            log_string('**** EPOCH %03d ****' % (epoch))
            sys.stdout.flush()
            
            train_one_epoch(sess, ops, train_writer)
            accuracy = eval_one_epoch(sess, ops, test_writer)
            # Save the network that achieves the best validation accuracy.
            # There are only training set and validation set for ModelNet40. 
            # Previous researchers report their best accuracy rather than 
            # final accuracy because they also regard the testing set as 
            # validation set.
            if accuracy > best_accuracy:
                best_accuracy = accuracy
                save_path = saver.save(sess, os.path.join(
                        LOG_DIR, FLAGS.model+ NAME_MODEL +"_model.ckpt"))
                log_string("Best accuracy, model saved in file: %s" % save_path)
        
        # Save the extracted global feature 
        save_global_feature(sess, ops, saver, layers)

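The LDGCNN authors note that earlier work did the same thing: it reported the best model rather than the final one. Try this change and see whether your results improve.

Add a new function, save_global_feature: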
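The save logic above writes a checkpoint only when test accuracy strictly improves. The file left on disk therefore holds the best epoch, not the last one. Below is a framework-free sketch of that bookkeeping. best_checkpoint_epochs is a hypothetical helper, and the accuracies in the usage line are made up for illustration.

```python
def best_checkpoint_epochs(accuracies):
    """Given per-epoch validation accuracies, return
    (epochs at which a checkpoint is saved, best epoch, best accuracy),
    following the 'save only when accuracy > best_accuracy' rule."""
    best_accuracy, best_epoch, saved = 0.0, None, []
    for epoch, accuracy in enumerate(accuracies):
        if accuracy > best_accuracy:  # strict improvement only
            best_accuracy, best_epoch = accuracy, epoch
            saved.append(epoch)
    return saved, best_epoch, best_accuracy


# Made-up accuracies: the last epoch (0.89) is not the one that stays on disk.
print(best_checkpoint_epochs([0.80, 0.86, 0.85, 0.91, 0.89]))  # ([0, 1, 3], 3, 0.91)
```

save_global_feature restores exactly this best checkpoint before extracting features. The t-SNE plots therefore show the best-epoch network, not the final one.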
def save_global_feature(sess, ops, saver, end_points):
    feature_name = 'global_feature'
    file_name_vec = ['train_' + feature_name, 'test_' + feature_name]
    Files_vec = [TRAIN_FILES, TEST_FILES]
    #Restore variables that achieves the best validation accuracy from the disk.
    saver.restore(sess, os.path.join(LOG_DIR, FLAGS.model+
                                         str(NAME_MODEL)+ "_model.ckpt")) 
    log_string("Model restored.") 
    is_training = False
    # Extract the features from training set and validation set.
    for r in range(2):
        file_name = file_name_vec[r]
        Files = Files_vec[r]
        global_feature_vec = np.array([])
        label_vec = np.array([])
        for fn in range(len(Files)):
            log_string('----'+str(fn)+'----')
            current_data, current_label = provider.loadDataFile(Files[fn])
            current_data = current_data[:,0:NUM_POINT,:]
            current_label = np.squeeze(current_label)
            print(current_data.shape)
            
            file_size = current_data.shape[0]
            num_batches = file_size // BATCH_SIZE
            print(file_size)
            
            for batch_idx in range(num_batches):
                start_idx = batch_idx * BATCH_SIZE
                end_idx = (batch_idx+1) * BATCH_SIZE
                # Input the point cloud and labels to the graph.
                feed_dict = {ops['pointclouds_pl']: current_data[start_idx:end_idx, :, :],
                             ops['labels_pl']: current_label[start_idx:end_idx],
                             ops['is_training_pl']: is_training}
                # Extract the global features from the input batch data.
                global_feature = np.squeeze(end_points[feature_name].eval(
                    feed_dict=feed_dict,session=sess))
                
                if label_vec.shape[0] == 0:
                    global_feature_vec = global_feature
                    label_vec = current_label[start_idx:end_idx]
                else:
                    global_feature_vec = np.concatenate([global_feature_vec, global_feature])
                    label_vec = np.concatenate([label_vec, current_label[start_idx:end_idx]])      
        # Save all global features to the disk.
        FileIO.write_h5('data/extracted_feature/' + file_name + '.h5', global_feature_vec, label_vec)
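The batching in save_global_feature has one side effect worth knowing. Because num_batches = file_size // BATCH_SIZE, the last file_size % BATCH_SIZE samples of every .h5 file never reach the saved features, since the placeholders have a fixed batch size. Below is a NumPy-only sketch of the same loop. collect_features is hypothetical, and extract_fn stands in for end_points[feature_name].eval(feed_dict=..., session=sess).

```python
import numpy as np


def collect_features(data, labels, batch_size, extract_fn):
    """Run extract_fn over the full batches of one file and stack the results.

    Like save_global_feature, only N // batch_size batches are processed,
    so the trailing N % batch_size samples are skipped.
    """
    feats, labs = [], []
    num_batches = data.shape[0] // batch_size
    for b in range(num_batches):
        start, end = b * batch_size, (b + 1) * batch_size
        feats.append(extract_fn(data[start:end]))
        labs.append(labels[start:end])
    if not feats:  # file smaller than one batch: nothing extracted
        return np.empty((0,)), np.empty((0,), dtype=labels.dtype)
    return np.concatenate(feats), np.concatenate(labs)
```

With 10 samples and a batch size of 4, only 8 feature rows come back. Collecting batches in a list and concatenating once gives the same arrays as the original's first-batch check plus repeated np.concatenate, without re-copying the growing array. If every sample must be kept, one untested option is to pad the final batch to BATCH_SIZE and drop the padded rows afterwards.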

4. Changes to models/dgcnn.py

  # MLP on global point cloud vector
  net = tf.reshape(net, [batch_size, -1]) 
  net = tf_util.fully_connected(net, 512, bn=True, is_training=is_training,
                                scope='fc1', bn_decay=bn_decay)
  net = tf_util.dropout(net, keep_prob=0.5, is_training=is_training,
                         scope='dp1')
  net = tf_util.fully_connected(net, 256, bn=True, is_training=is_training,
                                scope='fc2', bn_decay=bn_decay)
  net = tf_util.dropout(net, keep_prob=0.5, is_training=is_training,
                        scope='dp2')
  net = tf_util.fully_connected(net, 40, activation_fn=None, scope='fc3')
  return net, end_points

to

  # MLP on global point cloud vector
  layers = {}
  net = tf.reshape(net, [batch_size, -1]) 
  layers['global_feature'] = net
  net = tf_util.fully_connected(net, 512, bn=True, is_training=is_training,
                                scope='fc1', bn_decay=bn_decay)
  layers['fc1'] = net
  net = tf_util.dropout(net, keep_prob=0.5, is_training=is_training,
                         scope='dp1')
  net = tf_util.fully_connected(net, 256, bn=True, is_training=is_training,
                                scope='fc2', bn_decay=bn_decay)
  layers['fc2'] = net
  net = tf_util.dropout(net, keep_prob=0.5, is_training=is_training,
                        scope='dp2')
  net = tf_util.fully_connected(net, 40, activation_fn=None, scope='fc3')
  layers['fc3'] = net

  return net, layers

Then change the get_loss function that follows (end_points is never used in it anyway) from

def get_loss(pred, label,end_points):

to

def get_loss(pred, label):
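The layers dict is what makes extraction possible. get_model now exposes named intermediate outputs, and save_global_feature picks one by key (feature_name = 'global_feature'). Below is the same pattern as a toy NumPy forward pass. It assumes the 1024-dimensional pooled feature of DGCNN's classification model. Weights are random, and batch norm and dropout are omitted (inference only), so it illustrates the data flow, not the trained network. mlp_head_with_layers is hypothetical.

```python
import numpy as np


def mlp_head_with_layers(global_feature, rng=None):
    """Toy version of the modified head (1024 -> 512 -> 256 -> 40).

    Returns the logits plus a dict of intermediate outputs keyed like the
    layers dict in dgcnn.py.
    """
    rng = np.random.default_rng(0) if rng is None else rng
    layers = {"global_feature": global_feature}
    net = global_feature
    for name, width in (("fc1", 512), ("fc2", 256)):
        w = rng.normal(scale=0.02, size=(net.shape[1], width))
        net = np.maximum(net @ w, 0.0)  # fully connected + ReLU
        layers[name] = net
    w = rng.normal(scale=0.02, size=(net.shape[1], 40))
    net = net @ w  # fc3: no activation, one logit per ModelNet40 class
    layers["fc3"] = net
    return net, layers
```

In principle, setting feature_name in save_global_feature to 'fc1' or 'fc2' would extract those activations instead. The output files would then be named train_fc1.h5 and so on, so the list files from step 1 would need matching entries.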

Running

Training:

python train.py

Testing:

python evaluate.py

T-SNE visualization:

python tsne_visualization.py

TensorBoard:

tensorboard --logdir log

Results

[Figure: data T-SNE]
[Figure: feature T-SNE]

Done.
