A 3D CNN implementation in TensorFlow

Original post: http://blog.csdn.net/sinat_31824577/article/details/60325571


This follows Udacity's Deep Learning course: https://cn.udacity.com/course/deep-learning–ud730
A 3D convolutional neural network is simply a 2D one with an extra dimension; in principle the two are almost identical. Taking a 2D network and using it directly in 3D, however, runs into all sorts of problems: some libraries do not support 3D convolution at all (Caffe appears not to), and Theano lacks a 3D max-pooling op, so the missing operations have to be written by hand. TensorFlow ships everything needed, which makes building a 3D CNN under it very convenient.
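For instance, 3D convolution and 3D max-pooling are available natively. A minimal sketch against the TF 1.x low-level API (the shapes here just mirror the dummy data used later and are otherwise arbitrary):

import tensorflow as tf

# NDHWC input: [batch, depth, height, width, channels]
x = tf.placeholder(tf.float32, [None, 65, 52, 51, 1])
# 3D kernel: [k_depth, k_height, k_width, in_channels, out_channels]
w = tf.get_variable('w3d', [3, 3, 3, 1, 10])

conv = tf.nn.conv3d(x, w, strides=[1, 1, 1, 1, 1], padding='SAME')
pool = tf.nn.max_pool3d(conv, ksize=[1, 2, 2, 2, 1],
                        strides=[1, 2, 2, 2, 1], padding='SAME')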

1. 2D CNN

The figure below shows the classic LeNet-5 model: conv-pool-conv-pool-conv-fullconnect-softmax, with every convolution kernel of size 5*5. Low-level pixel patterns are learned layer by layer through convolution, eventually condensed into an 84-dimensional vector, and a final multi-class regression (the softmax layer) outputs the class prediction. For a detailed introduction see http://blog.csdn.net/xuanyuansen/article/details/41800721
[Figure: the LeNet-5 architecture]
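For reference, a LeNet-5-style stack can be sketched with the same tf.contrib.layers API used by the 3D code below (an illustrative sketch, not the original LeNet code; layer widths follow the classic paper):

import tensorflow as tf
from tensorflow.contrib.layers import convolution2d, max_pool2d, flatten, fully_connected
from tensorflow.python.ops.nn import relu, softmax

x = tf.placeholder(tf.float32, [None, 32, 32, 1])           # NHWC grayscale input
net = convolution2d(x, 6, [5, 5], activation_fn=relu)       # conv, 5*5 kernels
net = max_pool2d(net, [2, 2])                               # pool
net = convolution2d(net, 16, [5, 5], activation_fn=relu)    # conv, 5*5 kernels
net = max_pool2d(net, [2, 2])                               # pool
net = fully_connected(flatten(net), 120, activation_fn=relu)
net = fully_connected(net, 84, activation_fn=relu)          # the 84-dimensional vector
y = fully_connected(net, 10, activation_fn=softmax)         # softmax class probabilities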

2. 3D CNN

In principle, take the 2D CNN and add one dimension to every variable, as the shape comparison below illustrates.
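Concretely, every shape gains one spatial axis (the sizes below match the dummy data used in the code):

import tensorflow as tf

x2d = tf.placeholder(tf.float32, [None, 52, 51, 1])      # 2D: images, NHWC
x3d = tf.placeholder(tf.float32, [None, 65, 52, 51, 1])  # 3D: volumes, NDHWC
# kernels and pooling windows grow likewise: 3*3 -> 3*3*3, 2*2 -> 2*2*2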

3. Implementation in TensorFlow

Here, ConfusionMatrix is the class used to score the results; it accumulates predictions batch by batch. A usage example follows its listing below.

a. 3D_cnn.py

#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Thu Feb 23 10:51:28 2017

@author: cdn
"""



import numpy as np
np.random.seed(1234)
import timeit
import os
import matplotlib.pyplot as plt

from sklearn.cross_validation import StratifiedKFold  # on scikit-learn >= 0.18 this lives in sklearn.model_selection (with a changed API)

import tensorflow as tf
from tensorflow.contrib.layers import fully_connected, convolution2d, flatten, dropout
from tensorflow.python.layers.pooling import max_pooling3d
from tensorflow.python.ops.nn import relu,softmax
from tensorflow.python.framework.ops import reset_default_graph
from ConfusionMatrix import ConfusionMatrix


def onehot(t, num_classes):
    out = np.zeros((t.shape[0], num_classes))
    for row, col in enumerate(t):
        out[row, col] = 1
    return out

def load_data(fold_index):

    np.random.seed(1234)
    A_data = np.random.uniform(-0.5,1.5, (500, 65, 52, 51)).astype('float32')
    B_data = np.random.uniform(0,1, (500, 65, 52, 51)).astype('float32')  # dummy volumes standing in for the two classes to classify

    A_num,sizeX,sizeY,sizeZ = A_data.shape
    B_num,_,_,_ = B_data.shape
    size_input = [1,sizeX,sizeY,sizeZ]
    np.random.seed(1234)
    random_idx = np.random.permutation(A_num+B_num)
    all_data = np.concatenate((A_data,B_data),axis=0)[random_idx]
    labels = np.hstack((np.ones((A_num,)),np.zeros((B_num,))))[random_idx]

    nn = 5  # number of folds
    skf = StratifiedKFold(labels,nn)
    train_id = ['']*nn
    test_id = ['']*nn
    a = 0
    for train,test in skf:
        train_id[a] = train
        test_id[a] = test
        a = a+1
    testid = test_id[fold_index]
    validid = test_id[fold_index-1]  # reuse the previous fold's test split as the validation set
    trainid = list(set(train_id[fold_index])-set(validid))
    x_train = all_data[trainid]
    y_train = labels[trainid]
    x_test = all_data[testid]
    y_test = labels[testid]
    x_valid = all_data[validid]
    y_valid = labels[validid] 
    return x_train,y_train,x_test,y_test,x_valid,y_valid,size_input


n_fold = 5
train_accuracy = np.zeros((n_fold,))
test_accuracy = np.zeros((n_fold,))
valid_accuracy = np.zeros((n_fold,))

t1_time = timeit.default_timer()
for fi in range(n_fold):  

    print('Now running on fold %d' % (fi+1))
    x_train, y_train, x_test, y_test, x_valid, y_valid, size_input = load_data(fi)
    nchannels, rows, cols, deps = size_input

    x_train = x_train.astype('float32').reshape((-1, nchannels, rows, cols, deps))
    targets_train = y_train.astype('int32')

    x_valid = x_valid.astype('float32').reshape((-1, nchannels, rows, cols, deps))
    targets_valid = y_valid.astype('int32')

    x_test = x_test.astype('float32').reshape((-1, nchannels, rows, cols, deps))
    targets_test = y_test.astype('int32')

    print('%d train / %d validation / %d test samples'
          % (x_train.shape[0], x_valid.shape[0], x_test.shape[0]))

    # define the 3D convolutional network

    # hyperparameters of the model
    num_classes = 2
    channels = x_train.shape[1]
    height = x_train.shape[2]
    width = x_train.shape[3]
    depth = x_train.shape[4]

    num_filters_conv1 = 10
    num_filters_conv2 = 25
    num_filters_conv3 = 40
    num_filters_conv4 = 40
    kernel_size_conv1 = [3, 3, 3]  # [height, width, depth]
    pool_size = [2, 2, 2]
    stride_conv1 = [1, 1, 1]       # [stride_height, stride_width, stride_depth]
    num_l1 = 100
    # resetting the graph ...
    reset_default_graph()

    # Setting up the placeholder: this is where the data enters the graph
    x_pl = tf.placeholder(tf.float32, [None, channels, height, width, depth])
    l_reshape = tf.transpose(x_pl, [0, 2, 3, 4, 1])  # TensorFlow expects channels last (NDHWC), not channels first (NCDHW)
    is_training = tf.placeholder(tf.bool)  # used for dropout

    # Building the layers of the network.
    # Variable scopes make the variables easier to recognise later.
    # Note: in tf.contrib.layers, convolution2d is an alias for the rank-generic
    # convolution op, so on a 5D input it actually performs a 3D convolution.
    l_conv1 = convolution2d(l_reshape, num_filters_conv1, kernel_size_conv1, stride_conv1, activation_fn=relu, scope="l_conv1")

    l_maxpool1 = max_pooling3d(l_conv1, pool_size, pool_size)

    l_conv2 = convolution2d(l_maxpool1, num_filters_conv2, kernel_size_conv1, stride_conv1, activation_fn=relu, scope="l_conv2")

    l_maxpool2 = max_pooling3d(l_conv2, pool_size, pool_size)

    l_conv3 = convolution2d(l_maxpool2, num_filters_conv3, kernel_size_conv1, stride_conv1, activation_fn=relu, scope="l_conv3")

    l_maxpool3 = max_pooling3d(l_conv3, pool_size, pool_size)

    l_conv4 = convolution2d(l_maxpool3, num_filters_conv4, kernel_size_conv1, stride_conv1, activation_fn=relu, scope="l_conv4")

    l_flatten = flatten(l_conv4, scope="flatten")

    l1 = fully_connected(l_flatten, num_l1, activation_fn=relu, scope="l1")

    l1 = dropout(l1, is_training=is_training, scope="dropout")

    y = fully_connected(l1, num_classes, activation_fn=softmax, scope="y")

    # y_ is a placeholder variable taking on the value of the target batch.
    y_ = tf.placeholder(tf.float32, [None, num_classes])

    # computing cross entropy per sample
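    # (the 1e-8 keeps tf.log finite when the softmax output is exactly zero)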
    cross_entropy = -tf.reduce_sum(y_ * tf.log(y+1e-8), reduction_indices=[1])

    # averaging over samples
    cross_entropy = tf.reduce_mean(cross_entropy)

    # defining our optimizer
    optimizer = tf.train.AdamOptimizer(learning_rate=0.001)

    # applying the gradients
    train_op = optimizer.minimize(cross_entropy)

    #Test the forward pass
#    x = np.random.normal(0,1, (45, 1,65, 52, 51)).astype('float32') #dummy data

    # restricting memory usage, TensorFlow is greedy and will use all memory otherwise
    gpu_opts = tf.GPUOptions(per_process_gpu_memory_fraction=0.2)
    # initialize the Session
    sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_opts))
    sess.run(tf.global_variables_initializer())
#    res = sess.run(fetches=[y], feed_dict={x_pl: x})
#    res = sess.run(fetches=[y], feed_dict={x_pl: x, is_training: False}) # for when using dropout
#    print "y", res[0].shape

    # Training loop (ConfusionMatrix is already imported at the top of the file)
    batch_size = 10
    num_epochs = 50
    num_samples_train = x_train.shape[0]
    num_batches_train = num_samples_train // batch_size
    num_samples_valid = x_valid.shape[0]
    num_batches_valid = num_samples_valid // batch_size

    train_acc, train_loss = [], []
    valid_acc, valid_loss = [], []
    test_acc, test_loss = [], []
    cur_loss = 0
    loss = []

    try:
        for epoch in range(num_epochs):
            # forward -> backprop -> parameter update
            cur_loss = 0
            for i in range(num_batches_train):
                idx = range(i*batch_size, (i+1)*batch_size)
                x_batch = x_train[idx]
                target_batch = targets_train[idx]
#                feed_dict_train = {x_pl: x_batch, y_: onehot(target_batch, num_classes)}
                feed_dict_train = {x_pl: x_batch, y_: onehot(target_batch, num_classes), is_training: True}
                fetches_train = [train_op, cross_entropy]
                res = sess.run(fetches=fetches_train, feed_dict=feed_dict_train)
                batch_loss = res[1]  # train_op has already performed the full backprop pass
                cur_loss += batch_loss
            loss += [cur_loss/num_batches_train]  # average loss per batch this epoch

            confusion_valid = ConfusionMatrix(num_classes)
            confusion_train = ConfusionMatrix(num_classes)

            for i in range(num_batches_train):
                idx = range(i*batch_size, (i+1)*batch_size)
                x_batch = x_train[idx]
                targets_batch = targets_train[idx]
                # evaluating on the training set (dropout disabled)
                feed_dict_eval_train = {x_pl: x_batch, is_training: False}
                fetches_eval_train = [y]
                res = sess.run(fetches=fetches_eval_train, feed_dict=feed_dict_eval_train)
                # collecting and storing predictions
                net_out = res[0] 
                preds = np.argmax(net_out, axis=-1) 
                confusion_train.batch_add(targets_batch, preds)

            confusion_valid = ConfusionMatrix(num_classes)
            for i in range(num_batches_valid):
                idx = range(i*batch_size, (i+1)*batch_size)
                x_batch = x_valid[idx]
                targets_batch = targets_valid[idx]
                # evaluating on the validation set (dropout disabled)
                feed_dict_eval_valid = {x_pl: x_batch, is_training: False}
                fetches_eval_valid = [y]
                res = sess.run(fetches=fetches_eval_valid, feed_dict=feed_dict_eval_valid)
                # collecting and storing predictions
                net_out = res[0]
                preds = np.argmax(net_out, axis=-1) 

                confusion_valid.batch_add(targets_batch, preds)

            train_acc_cur = confusion_train.accuracy()
            valid_acc_cur = confusion_valid.accuracy()

            train_acc += [train_acc_cur]
            valid_acc += [valid_acc_cur]
            print "Epoch %i : Train Loss %e , Train acc %f,  Valid acc %f " \
            % (epoch+1, loss[-1], train_acc_cur, valid_acc_cur)
    except KeyboardInterrupt:
        pass


    # get the test-set score
    confusion_test = ConfusionMatrix(num_classes)

    # evaluating on the test set (dropout disabled)
    feed_dict_eval_test = {x_pl: x_test, is_training: False}
    fetches_eval_test = [y]
    res = sess.run(fetches=fetches_eval_test, feed_dict=feed_dict_eval_test)
    # collecting and storing predictions
    net_out = res[0] 
    preds = np.argmax(net_out, axis=-1) 
    confusion_test.batch_add(targets_test, preds)
    print "\nTest set Acc:  %f" %(confusion_test.accuracy())
    test_acc = confusion_test.accuracy()

    epochs = np.arange(len(train_acc))
    plt.figure()
    plt.plot(epochs,train_acc,'r',epochs,valid_acc,'b')
    plt.legend(['Train Acc','Val Acc'])
    plt.xlabel('Epochs'), plt.ylabel('Acc'), plt.ylim([0.2,1.03])
    plt.show()
    train_accuracy[fi] = train_acc[-1]
    test_accuracy[fi] = test_acc
    valid_accuracy[fi] = valid_acc[-1]

print '\nMean accuracy of test set: %f %%' %(np.mean(test_accuracy)*100)

t2_time = timeit.default_timer()
print(('The code for Tensorflow ' +os.path.split(__file__)[1] +' ran for %.2fm' % ((t2_time - t1_time) / 60.))) 

b. ConfusionMatrix.py

import numpy as np


class ConfusionMatrix:
    """
       Simple confusion matrix class
       row is the true class, column is the predicted class
    """
    def __init__(self, num_classes, class_names=None):
        self.n_classes = num_classes
        if class_names is None:
            self.class_names = map(str, range(num_classes))
        else:
            self.class_names = class_names

        # find max class_name and pad
        max_len = max(map(len, self.class_names))
        self.max_len = max_len
        for idx, name in enumerate(self.class_names):
            if len(name) < max_len:
                self.class_names[idx] = name + " "*(max_len-len(name))

        self.mat = np.zeros((num_classes,num_classes),dtype='int')

    def __str__(self):
        # calculate row and column sums
        row_sum = np.sum(self.mat, axis=1)   # true-class totals (one per row)
        col_sum = np.sum(self.mat, axis=0)   # predicted-class totals (one per column)

        s = []

        mat_str = self.mat.__str__()
        mat_str = mat_str.replace('[','').replace(']','').split('\n')

        for idx, row in enumerate(mat_str):
            if idx == 0:
                pad = " "
            else:
                pad = ""
            class_name = self.class_names[idx]
            class_name = " " + class_name + " |"
            row_str = class_name + pad + row
            row_str += " |" + str(row_sum[idx])
            s.append(row_str)

        footer = [(self.max_len+4)*" "+" ".join(map(str, col_sum))]
        hline = [(1+self.max_len)*" "+"-"*len(footer[0])]

        s = hline + s + hline + footer

        # add linebreaks
        s_out = [line+'\n' for line in s]
        return "".join(s_out)

    def batch_add(self, targets, preds):
        assert targets.shape == preds.shape
        assert len(targets) == len(preds)
        assert max(targets) < self.n_classes
        assert max(preds) < self.n_classes
        targets = targets.flatten()
        preds = preds.flatten()
        for i in range(len(targets)):
                self.mat[targets[i], preds[i]] += 1

    def get_errors(self):
        # returns per-class counts in the order (tp, fn, fp, tn)
        tp = np.asarray(np.diag(self.mat).flatten(),dtype='float')
        fn = np.asarray(np.sum(self.mat, axis=1).flatten(),dtype='float') - tp
        fp = np.asarray(np.sum(self.mat, axis=0).flatten(),dtype='float') - tp
        tn = np.asarray(np.sum(self.mat)*np.ones(self.n_classes).flatten(),
                        dtype='float') - tp - fn - fp
        return tp, fn, fp, tn

    def accuracy(self):
        """
        Calculates global accuracy
        :return: accuracy
        :example: >>> conf = ConfusionMatrix(3)
                  >>> conf.batch_add(np.array([0,0,1]), np.array([0,0,2]))
                  >>> print conf.accuracy()
        """
        tp, _, _, _ = self.get_errors()
        n_samples = np.sum(self.mat)
        return np.sum(tp) / n_samples

    def sensitivity(self):
        tp, fn, fp, tn = self.get_errors()
        res = tp / (tp + fn)
        res = res[~np.isnan(res)]
        return res

    def specificity(self):
        tp, fn, fp, tn = self.get_errors()
        res = tn / (tn + fp)
        res = res[~np.isnan(res)]
        return res

    def positive_predictive_value(self):
        tp, fn, fp, tn = self.get_errors()
        res = tp / (tp + fp)
        res = res[~np.isnan(res)]
        return res

    def negative_predictive_value(self):
        tp, fn, fp, tn = self.get_errors()
        res = tn / (tn + fn)
        res = res[~np.isnan(res)]
        return res

    def false_positive_rate(self):
        tp, fn, fp, tn = self.get_errors()
        res = fp / (fp + tn)
        res = res[~np.isnan(res)]
        return res

    def false_discovery_rate(self):
        tp, fn, fp, tn = self.get_errors()
        res = fp / (tp + fp)
        res = res[~np.isnan(res)]
        return res

    def F1(self):
        tp, fn, fp, tn = self.get_errors()
        res = (2*tp) / (2*tp + fp + fn)
        res = res[~np.isnan(res)]
        return res

    def matthews_correlation(self):
        tp, fn, fp, tn = self.get_errors()
        numerator = tp*tn - fp*fn
        denominator = np.sqrt((tp + fp)*(tp + fn)*(tn + fp)*(tn + fn))
        res = numerator / denominator
        res = res[~np.isnan(res)]
        return res
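
For reference, a minimal usage sketch of the class (note that batch_add expects numpy arrays, since it calls .flatten() on its arguments):

import numpy as np
from ConfusionMatrix import ConfusionMatrix

conf = ConfusionMatrix(2)
conf.batch_add(np.array([0, 0, 1, 1]), np.array([0, 1, 1, 1]))
print conf.accuracy()   # 0.75
print conf              # 2x2 matrix with row and column totals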