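I had read the U-net code before but never actually ran it. Reading related blog posts, I also picked up some of the questions beginners tend to have about U-net:
1. The general structure of U-net, and the structure given in the paper
2. How U-net does data augmentation
3. How U-net is implemented in code
4. U-net's loss function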
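If you are in the habit of reading papers, the first things to look at are the paper's application area and its advantages over earlier work.
This blog post explains what U-net is for and what its characteristics are: https://blog.csdn.net/u012931582/article/details/70215756
U-net belongs to the FCN (fully convolutional network) family: both the input and the output are images, and there are no fully connected layers. The shallower, high-resolution layers handle locating pixels, while the deeper layers handle classifying them. Training is end-to-end, and image style transfer and image super-resolution also use this kind of framework.
With the purpose of the U-net structure roughly explained, let's go through the questions listed at the start of the post: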
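1. The general structure of U-net:
You have probably all seen this figure from the paper:
The first layer:
The input is a 572×572×1 image, whereas the original image is actually 512×512×1. The image keeps shrinking during the 3×3 convolutions because the paper uses 'VALID' convolutions rather than 'SAME'. U-Net uses a mirroring operation (the overlap-tile strategy) so that pixels at the image border can also be segmented accurately. Mirroring means adding a symmetric border to the input image. How wide should that border be? A good rule is to derive it from the receptive field. Valid convolutions reduce the resolution of the feature map, but we want the border pixels of the 512×512 image to survive all the way to the last feature map. So we enlarge the image by a total amount equal to the receptive field, which means each edge gets a mirrored border half the size of the receptive field.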
This figure is taken from Zhihu.
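The mirroring step can be sketched with NumPy's `np.pad` in `'reflect'` mode. This is a minimal illustration that assumes a single-channel 2-D image. `mirror_pad` is a hypothetical helper name, and the 30-pixel border only reproduces the post's 512 → 572 arithmetic.

```python
import numpy as np


def mirror_pad(image, border):
    """Extend a 2-D image by mirroring `border` pixels on every edge.

    mode='reflect' mirrors about the edge pixel without repeating it;
    mode='symmetric' would repeat the edge pixel instead.
    """
    return np.pad(image, pad_width=border, mode='reflect')


# Hypothetical 512x512 tile; a border of (572 - 512) / 2 = 30 follows the post's arithmetic.
tile = np.arange(512 * 512, dtype=np.float32).reshape(512, 512)
padded = mirror_pad(tile, border=30)
print(padded.shape)  # (572, 572)
```

The interior of `padded` is the original tile. Row 29 of the padded image mirrors row 1 of the tile, row 0 mirrors row 30, and so on.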
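From the contracting-path architecture shown in the figure, we can compute its receptive field:
This is why U-Net's input is 572×572. Another benefit of 572 is that the feature map has an even size at every downsampling step, so every 2×2 max-pool divides it exactly. This value is closely tied to the network structure. Related post: https://zhuanlan.zhihu.com/p/43927696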
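The size arithmetic can be checked with a few lines of Python. This sketch assumes the paper's layout: four levels, two unpadded 3×3 convolutions per level, 2×2 max-pooling, and 2×2 up-convolutions on the way back up. `trace_valid_unet` is a hypothetical helper. It traces feature-map sizes only and does not compute a receptive field.

```python
def trace_valid_unet(size, depth=4):
    """Trace the spatial size through a U-Net with unpadded ('VALID') 3x3 convs.

    Each contracting level: two 3x3 convs (each removes 2 pixels), then a 2x2 max-pool.
    Each expansive level: a 2x2 up-conv doubles the size, then two 3x3 convs.
    Returns the sizes fed to each max-pool and the final output size.
    """
    pooled_inputs = []
    for _ in range(depth):
        size -= 4                      # two valid 3x3 convs
        if size % 2 != 0:
            raise ValueError('odd size %d before max-pool' % size)
        pooled_inputs.append(size)
        size //= 2                     # 2x2 max-pool, stride 2
    size -= 4                          # bottom: two valid 3x3 convs
    for _ in range(depth):
        size = size * 2 - 4            # 2x2 up-conv, then two valid 3x3 convs
    return pooled_inputs, size


print(trace_valid_unet(572))  # ([568, 280, 136, 64], 388)
```

Calling `trace_valid_unet(512)` raises a `ValueError` before the third pooling step (121 is odd). This is consistent with the remark that 572 keeps every pooled size even.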
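Take the left half as an example: it is an ordinary CNN structure. Each unit consists of two 3×3 conv+ReLU layers followed by a 2×2 max-pool. The first conv+ReLU doubles the number of channels and the second keeps it unchanged (the very first unit is the exception, going from 1 input channel to 64).
The bottom layer: the same as before, two convolutions, the first doubling the channels and the second keeping them unchanged (512 → 1024 → 1024).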
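The right half:
The transition between the blue boxes is done with a 2×2 up-convolution, which doubles the side length and halves the number of channels. This is a "deconvolution" (strictly speaking a transposed convolution; see the paper on visualizing convolutional-network features). There are also white boxes. The white boxes and gray arrows mean that earlier features are concatenated with the current ones (a concat operation; DenseNet does something similar). The reason is that repeated downsampling makes the information more and more abstract, but it also loses information along the way. Combining it with information from earlier layers helps the network decide the segmentation. Take the concat at the top level, between the near-full-resolution features of the first level and the last layer. Upsampling alone might only recover the rough region, since the features were downsampled to a very small size. Combining it with the high-resolution features localizes the segmentation accurately.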
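The up-convolution and the copy-and-crop concat can be sketched in NumPy. This is an illustration, not the TensorFlow implementation. `up_conv_2x2` and `crop_and_concat` are hypothetical helpers, and the channel counts are scaled down (8 and 4 instead of 1024 and 512). The spatial sizes follow the paper's fourth level: a 28×28 bottom map is upsampled to 56×56, and the 64×64 contracting-path map is center-cropped to match. Note that `tf.nn.conv2d_transpose`, used in the code below, lays out its filter as `[height, width, output_channels, in_channels]`, whereas this sketch uses `(2, 2, C_in, C_out)`.

```python
import numpy as np


def up_conv_2x2(x, w):
    """2x2 transposed convolution with stride 2 (output patches do not overlap).

    x: (H, W, C_in) feature map; w: (2, 2, C_in, C_out) kernel.
    Each input pixel writes one 2x2 output patch, so the side length doubles.
    """
    h, wd, _ = x.shape
    out = np.zeros((2 * h, 2 * wd, w.shape[3]), dtype=x.dtype)
    for di in range(2):
        for dj in range(2):
            # out[2i + di, 2j + dj, :] = x[i, j, :] @ w[di, dj]
            out[di::2, dj::2, :] = x @ w[di, dj]
    return out


def crop_and_concat(skip, up):
    """Center-crop the contracting-path map to up's size, then concat on channels."""
    dh = (skip.shape[0] - up.shape[0]) // 2
    dw = (skip.shape[1] - up.shape[1]) // 2
    cropped = skip[dh:dh + up.shape[0], dw:dw + up.shape[1], :]
    return np.concatenate([cropped, up], axis=-1)


rng = np.random.default_rng(0)
bottom = rng.standard_normal((28, 28, 8))   # stand-in for the 28x28x1024 bottom map
kernel = rng.standard_normal((2, 2, 8, 4))  # halves the channels: 8 -> 4
up = up_conv_2x2(bottom, kernel)            # (56, 56, 4)
skip = rng.standard_normal((64, 64, 4))     # stand-in for the 64x64x512 contracting map
merged = crop_and_concat(skip, up)          # (56, 56, 8)
```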
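The size difference between input and output comes from using 'VALID' convolutions. With 'SAME', input and output can have the same size, and that is the approach taken by the code at the end of this post. I have not run experiments with the same-size variant, so I can't say how well it performs, but most networks today are designed so that input and output sizes match.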
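With 'SAME' padding, the size changes only at the pooling and up-sampling steps. The minimal sketch below reproduces the hard-coded `output_shape` sizes 64, 128, 256 and 512 in the code below. It assumes a 512×512 input and four pooling steps. `trace_same_unet` is a hypothetical helper, and its divisibility check encodes the assumption that every max-pool should divide the size exactly.

```python
def trace_same_unet(size, depth=4):
    """Spatial sizes for a 'SAME'-padded U-Net.

    3x3 'SAME' convs keep the size, each 2x2 max-pool halves it, and each
    2x2 stride-2 transposed conv doubles it. Returns the up-sampling output
    sizes (what the transposed convs' output_shape must match) and the
    final output size.
    """
    if size % (2 ** depth) != 0:
        raise ValueError('input size must be divisible by %d' % (2 ** depth))
    bottom = size // (2 ** depth)
    up_sizes = [bottom * 2 ** k for k in range(1, depth + 1)]
    return up_sizes, up_sizes[-1]


print(trace_same_unet(512))  # ([64, 128, 256, 512], 512)
```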
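2. Augmenting the image data: because of the nature of biological images, deformed tissue still looks like plausible tissue, as shown in the figure below
Paper on image warping (Image Deformation Using Moving Least Squares): http://faculty.cs.tamu.edu/schaefer/research/mls.pdf
Noise can be added as well.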
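The U-Net paper describes smooth deformations built from random displacement vectors on a coarse 3×3 grid. The displacements are drawn from a Gaussian with a 10-pixel standard deviation and interpolated to per-pixel offsets with bicubic interpolation. The NumPy-only sketch below follows that idea with two plain substitutions. It uses bilinear interpolation instead of bicubic, and it clamps coordinates that fall outside the image to the border. `elastic_deform` and `bilinear_sample` are hypothetical names. The optional Gaussian noise covers the remark above.

```python
import numpy as np


def bilinear_sample(img, rows, cols):
    """Sample a 2-D array at fractional (row, col) positions, clamping to the border."""
    h, w = img.shape
    rows = np.clip(rows, 0, h - 1)
    cols = np.clip(cols, 0, w - 1)
    r0 = np.floor(rows).astype(int)
    c0 = np.floor(cols).astype(int)
    r1 = np.minimum(r0 + 1, h - 1)
    c1 = np.minimum(c0 + 1, w - 1)
    fr, fc = rows - r0, cols - c0
    top = img[r0, c0] * (1 - fc) + img[r0, c1] * fc
    bottom = img[r1, c0] * (1 - fc) + img[r1, c1] * fc
    return top * (1 - fr) + bottom * fr


def elastic_deform(image, label, grid=3, sigma=10.0, noise_std=0.0, rng=None):
    """Warp an image and its label map with one random smooth displacement field."""
    rng = np.random.default_rng() if rng is None else rng
    h, w = image.shape
    # random (row, col) displacement vectors on a coarse grid x grid lattice
    coarse = rng.normal(0.0, sigma, size=(2, grid, grid))
    # every pixel's position expressed in coarse-grid coordinates
    gr, gc = np.meshgrid(np.linspace(0, grid - 1, h), np.linspace(0, grid - 1, w), indexing='ij')
    dr = bilinear_sample(coarse[0], gr, gc)
    dc = bilinear_sample(coarse[1], gr, gc)
    rr, cc = np.meshgrid(np.arange(h), np.arange(w), indexing='ij')
    src_r, src_c = rr + dr, cc + dc
    warped = bilinear_sample(image.astype(np.float64), src_r, src_c)
    # nearest-neighbour lookup keeps the label map's class values intact
    nr = np.clip(np.rint(src_r), 0, h - 1).astype(int)
    nc = np.clip(np.rint(src_c), 0, w - 1).astype(int)
    warped_label = label[nr, nc]
    if noise_std > 0:
        warped = warped + rng.normal(0.0, noise_std, size=warped.shape)
    return warped, warped_label


rng = np.random.default_rng(0)
image = rng.random((64, 64))            # stand-in for a grayscale training image
label = (image > 0.5).astype(np.uint8)  # stand-in for its 0/1 mask
aug_image, aug_label = elastic_deform(image, label, rng=rng, noise_std=0.01)
```

The same displacement field is applied to the image and the mask, so the pair stays aligned. With `sigma=0` and no noise, the function returns the inputs unchanged.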
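There is one more detail to note here (see https://zhuanlan.zhihu.com/p/43927696): U-net's loss function. Sometimes the images to segment look like this, with cells packed tightly together, so the boundaries between them are very hard to detect. In that case the loss function has to be designed accordingly.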
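So how should the loss be designed so that the model learns to separate boundaries? U-Net uses a cross-entropy loss weighted per pixel, with boundary weights:
$$E = -\sum_{x \in \Omega} w(x) \log p_{\ell(x)}(x)$$
where $$p_{\ell(x)}(x)$$ is the softmax output for the true class of pixel $$x$$, $$\ell(x)$$ is the label of pixel $$x$$, and $$w(x)$$ is the pixel's weight, which gives pixels close to cell boundaries a higher weight:
$$w(x) = w_c(x) + w_0 \cdot \exp\left(-\frac{(d_1(x) + d_2(x))^2}{2\sigma^2}\right)$$
where $$w_c$$ is the weight map that balances the class frequencies, $$d_1(x)$$ is the distance from pixel $$x$$ to the border of the nearest cell, and $$d_2(x)$$ is the distance to the border of the second-nearest cell. $$w_0$$ and $$\sigma$$ are constants; in the experiments $$w_0 = 10$$ and $$\sigma \approx 5$$ pixels.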
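A NumPy sketch of the weight map and the weighted loss follows. `unet_weight_map` and `weighted_softmax_ce` are hypothetical names. Several choices here are assumptions made for illustration, not details fixed by the formulas above:
- distances are computed by brute force to the nearest pixel of each cell;
- $$w_c$$ is a simple inverse class frequency;
- the border term is applied to background pixels only.

```python
import numpy as np


def unet_weight_map(instances, w0=10.0, sigma=5.0):
    """w(x) = w_c(x) + w0 * exp(-(d1 + d2)^2 / (2 sigma^2)).

    instances: 2-D int array, 0 = background, 1..N = individual cells.
    """
    h, w = instances.shape
    rr, cc = np.meshgrid(np.arange(h), np.arange(w), indexing='ij')
    coords = np.stack([rr.ravel(), cc.ravel()], axis=1).astype(np.float64)
    cell_ids = [k for k in np.unique(instances) if k != 0]
    foreground = instances > 0
    # w_c: inverse class frequency with background weight 1 (assumed choice)
    fg_frac = foreground.mean()
    w_c = np.where(foreground, (1.0 - fg_frac) / max(fg_frac, 1e-12), 1.0)
    if len(cell_ids) < 2:
        return w_c  # d2 is undefined with fewer than two cells
    # distance from every pixel to the closest pixel of each cell: (num_cells, h * w)
    dists = []
    for k in cell_ids:
        cell = coords[instances.ravel() == k]
        diff = coords[:, None, :] - cell[None, :, :]
        dists.append(np.sqrt((diff ** 2).sum(axis=-1)).min(axis=1))
    dists = np.sort(np.stack(dists), axis=0)
    d1 = dists[0].reshape(h, w)
    d2 = dists[1].reshape(h, w)
    border = w0 * np.exp(-((d1 + d2) ** 2) / (2.0 * sigma ** 2))
    # border term on background pixels only (assumed choice)
    return w_c + np.where(foreground, 0.0, border)


def weighted_softmax_ce(logits, labels, weights):
    """E = -sum_x w(x) log p_{l(x)}(x); logits (H, W, K), labels (H, W) ints."""
    shifted = logits - logits.max(axis=-1, keepdims=True)  # numerically stable log-softmax
    log_p = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    h, w = labels.shape
    log_p_true = log_p[np.arange(h)[:, None], np.arange(w)[None, :], labels]
    return -(weights * log_p_true).sum()


inst = np.zeros((20, 20), dtype=int)
inst[5:15, 2:8] = 1    # cell 1
inst[5:15, 10:16] = 2  # cell 2, separated from cell 1 by a 2-pixel gap
wmap = unet_weight_map(inst)
```

Background pixels in the narrow gap between the two cells receive a weight near $$1 + w_0$$, while background far from both cells stays close to 1. In the TensorFlow 1.x code below, a precomputed map like this could be fed through an extra placeholder and passed as the `weights` argument of `tf.losses.sparse_softmax_cross_entropy`.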
For some images, though, this is unnecessary, for example:
Code walkthrough (with some comments):
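import os

import tensorflow as tf  # written against the TensorFlow 1.x API (placeholders, sessions, tf.contrib)

# The constants used below (INPUT_IMG_WIDE, INPUT_IMG_HEIGHT, INPUT_IMG_CHANNEL, OUTPUT_IMG_WIDE,
# CLASS_NUM, EPOCH_NUM, CHECK_POINT_PATH and the *_SET_NAME, *_BATCH_SIZE, *_DIRECTORY values),
# FLAGS and read_image_batch() are not shown in this post and must be defined elsewhere.


class Unet:

    def __init__(self):
        print('New U-net Network')
        self.input_image = None
        self.input_label = None
        self.cast_image = None
        self.cast_label = None
        self.keep_prob = None
        self.lamb = None
        self.result_expand = None
        self.loss, self.loss_mean, self.loss_all, self.train_step = [None] * 4
        self.prediction, self.correct_prediction, self.accuracy = [None] * 3
        self.result_conv = {}
        self.result_relu = {}
        self.result_maxpool = {}
        self.result_from_contract_layer = {}
        self.w = {}
        self.b = {}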
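
    def init_w(self, shape, name):
        with tf.name_scope('init_w'):
            # He initialization: stddev = sqrt(2 / fan_in), fan_in = k_h * k_w * shape[2]
            # (for the conv2d_transpose filters, shape[2] is the output channel count)
            # 2.0 avoids integer division under Python 2
            stddev = tf.sqrt(x=2.0 / (shape[0] * shape[1] * shape[2]))
            # stddev = 0.01
            w = tf.Variable(initial_value=tf.truncated_normal(shape=shape, stddev=stddev, dtype=tf.float32), name=name)
            # add an L2 penalty on w, scaled by the lambda placeholder, to the 'loss' collection
            tf.add_to_collection(name='loss', value=tf.contrib.layers.l2_regularizer(self.lamb)(w))
            return w

    @staticmethod
    def init_b(shape, name):
        with tf.name_scope('init_b'):
            return tf.Variable(initial_value=tf.random_normal(shape=shape, dtype=tf.float32), name=name)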

    @staticmethod
    def copy_and_crop_and_merge(result_from_contract_layer, result_from_upsampling):
        # With 'VALID' convolutions (as in the paper) the contracting-path map is larger than the
        # up-sampled one and must be center-cropped first, as in the commented-out tf.slice below.
        # This implementation uses 'SAME' convolutions, so both maps already have the same size.
        # result_from_contract_layer_shape = tf.shape(result_from_contract_layer)
        # result_from_upsampling_shape = tf.shape(result_from_upsampling)
        # result_from_contract_layer_crop = \
        #     tf.slice(
        #         input_=result_from_contract_layer,
        #         begin=[
        #             0,
        #             (result_from_contract_layer_shape[1] - result_from_upsampling_shape[1]) // 2,
        #             (result_from_contract_layer_shape[2] - result_from_upsampling_shape[2]) // 2,
        #             0
        #         ],
        #         size=[
        #             result_from_upsampling_shape[0],
        #             result_from_upsampling_shape[1],
        #             result_from_upsampling_shape[2],
        #             result_from_upsampling_shape[3]
        #         ]
        #     )
        result_from_contract_layer_crop = result_from_contract_layer
        return tf.concat(values=[result_from_contract_layer_crop, result_from_upsampling], axis=-1)

    def set_up_unet(self, batch_size):
        # input
        with tf.name_scope('input'):
            # learning_rate = tf.train.exponential_decay()
            self.input_image = tf.placeholder(
                dtype=tf.float32, shape=[batch_size, INPUT_IMG_WIDE, INPUT_IMG_WIDE, INPUT_IMG_CHANNEL], name='input_images'
            )
            # self.cast_image = tf.reshape(
            #     tensor=self.input_image,
            #     shape=[batch_size, INPUT_IMG_WIDE, INPUT_IMG_WIDE, INPUT_IMG_CHANNEL]
            # )
            # for softmax_cross_entropy_with_logits(labels=self.input_label, logits=self.prediction, name='loss')
            # using one-hot
            # self.input_label = tf.placeholder(
            #     dtype=tf.uint8, shape=[OUTPUT_IMG_WIDE, OUTPUT_IMG_WIDE], name='input_labels'
            # )
            # self.cast_label = tf.reshape(
            #     tensor=self.input_label,
            #     shape=[batch_size, OUTPUT_IMG_WIDE, OUTPUT_IMG_HEIGHT]
            # )
            # for sparse_softmax_cross_entropy_with_logits(labels=self.input_label, logits=self.prediction, name='loss')
            # not using one-hot coding
            self.input_label = tf.placeholder(
                dtype=tf.int32, shape=[batch_size, OUTPUT_IMG_WIDE, OUTPUT_IMG_WIDE], name='input_labels'
            )
            self.keep_prob = tf.placeholder(dtype=tf.float32, name='keep_prob')
            self.lamb = tf.placeholder(dtype=tf.float32, name='lambda')
        # layer 1
        with tf.name_scope('layer_1'):
            # conv_1
            self.w[1] = self.init_w(shape=[3, 3, INPUT_IMG_CHANNEL, 64], name='w_1')
            self.b[1] = self.init_b(shape=[64], name='b_1')
            result_conv_1 = tf.nn.conv2d(
                input=self.input_image, filter=self.w[1],
                strides=[1, 1, 1, 1], padding='SAME', name='conv_1')
            result_relu_1 = tf.nn.relu(tf.nn.bias_add(result_conv_1, self.b[1], name='add_bias'), name='relu_1')
            # conv_2
            self.w[2] = self.init_w(shape=[3, 3, 64, 64], name='w_2')
            self.b[2] = self.init_b(shape=[64], name='b_2')
            result_conv_2 = tf.nn.conv2d(
                input=result_relu_1, filter=self.w[2],
                strides=[1, 1, 1, 1], padding='SAME', name='conv_2')
            result_relu_2 = tf.nn.relu(tf.nn.bias_add(result_conv_2, self.b[2], name='add_bias'), name='relu_2')
            self.result_from_contract_layer[1] = result_relu_2  # keep this output for the skip connection in the expansive path
            # maxpool
            result_maxpool = tf.nn.max_pool(
                value=result_relu_2, ksize=[1, 2, 2, 1],
                strides=[1, 2, 2, 1], padding='VALID', name='maxpool')
            # dropout
            result_dropout = tf.nn.dropout(x=result_maxpool, keep_prob=self.keep_prob)
        # layer 2
        with tf.name_scope('layer_2'):
            # conv_1
            self.w[3] = self.init_w(shape=[3, 3, 64, 128], name='w_3')
            self.b[3] = self.init_b(shape=[128], name='b_3')
            result_conv_1 = tf.nn.conv2d(
                input=result_dropout, filter=self.w[3],
                strides=[1, 1, 1, 1], padding='SAME', name='conv_1')
            result_relu_1 = tf.nn.relu(tf.nn.bias_add(result_conv_1, self.b[3], name='add_bias'), name='relu_1')
            # conv_2
            self.w[4] = self.init_w(shape=[3, 3, 128, 128], name='w_4')
            self.b[4] = self.init_b(shape=[128], name='b_4')
            result_conv_2 = tf.nn.conv2d(
                input=result_relu_1, filter=self.w[4],
                strides=[1, 1, 1, 1], padding='SAME', name='conv_2')
            result_relu_2 = tf.nn.relu(tf.nn.bias_add(result_conv_2, self.b[4], name='add_bias'), name='relu_2')
            self.result_from_contract_layer[2] = result_relu_2  # keep this output for the skip connection in the expansive path
            # maxpool
            result_maxpool = tf.nn.max_pool(
                value=result_relu_2, ksize=[1, 2, 2, 1],
                strides=[1, 2, 2, 1], padding='VALID', name='maxpool')
            # dropout
            result_dropout = tf.nn.dropout(x=result_maxpool, keep_prob=self.keep_prob)
        # layer 3
        with tf.name_scope('layer_3'):
            # conv_1
            self.w[5] = self.init_w(shape=[3, 3, 128, 256], name='w_5')
            self.b[5] = self.init_b(shape=[256], name='b_5')
            result_conv_1 = tf.nn.conv2d(
                input=result_dropout, filter=self.w[5],
                strides=[1, 1, 1, 1], padding='SAME', name='conv_1')
            result_relu_1 = tf.nn.relu(tf.nn.bias_add(result_conv_1, self.b[5], name='add_bias'), name='relu_1')
            # conv_2
            self.w[6] = self.init_w(shape=[3, 3, 256, 256], name='w_6')
            self.b[6] = self.init_b(shape=[256], name='b_6')
            result_conv_2 = tf.nn.conv2d(
                input=result_relu_1, filter=self.w[6],
                strides=[1, 1, 1, 1], padding='SAME', name='conv_2')
            result_relu_2 = tf.nn.relu(tf.nn.bias_add(result_conv_2, self.b[6], name='add_bias'), name='relu_2')
            self.result_from_contract_layer[3] = result_relu_2  # keep this output for the skip connection in the expansive path
            # maxpool
            result_maxpool = tf.nn.max_pool(
                value=result_relu_2, ksize=[1, 2, 2, 1],
                strides=[1, 2, 2, 1], padding='VALID', name='maxpool')
            # dropout
            result_dropout = tf.nn.dropout(x=result_maxpool, keep_prob=self.keep_prob)
        # layer 4
        with tf.name_scope('layer_4'):
            # conv_1
            self.w[7] = self.init_w(shape=[3, 3, 256, 512], name='w_7')
            self.b[7] = self.init_b(shape=[512], name='b_7')
            result_conv_1 = tf.nn.conv2d(
                input=result_dropout, filter=self.w[7],
                strides=[1, 1, 1, 1], padding='SAME', name='conv_1')
            result_relu_1 = tf.nn.relu(tf.nn.bias_add(result_conv_1, self.b[7], name='add_bias'), name='relu_1')
            # conv_2
            self.w[8] = self.init_w(shape=[3, 3, 512, 512], name='w_8')
            self.b[8] = self.init_b(shape=[512], name='b_8')
            result_conv_2 = tf.nn.conv2d(
                input=result_relu_1, filter=self.w[8],
                strides=[1, 1, 1, 1], padding='SAME', name='conv_2')
            result_relu_2 = tf.nn.relu(tf.nn.bias_add(result_conv_2, self.b[8], name='add_bias'), name='relu_2')
            self.result_from_contract_layer[4] = result_relu_2  # keep this output for the skip connection in the expansive path
            # maxpool
            result_maxpool = tf.nn.max_pool(
                value=result_relu_2, ksize=[1, 2, 2, 1],
                strides=[1, 2, 2, 1], padding='VALID', name='maxpool')
            # dropout
            result_dropout = tf.nn.dropout(x=result_maxpool, keep_prob=self.keep_prob)
        # layer 5 (bottom)
        with tf.name_scope('layer_5'):
            # conv_1
            self.w[9] = self.init_w(shape=[3, 3, 512, 1024], name='w_9')
            self.b[9] = self.init_b(shape=[1024], name='b_9')
            result_conv_1 = tf.nn.conv2d(
                input=result_dropout, filter=self.w[9],
                strides=[1, 1, 1, 1], padding='SAME', name='conv_1')
            result_relu_1 = tf.nn.relu(tf.nn.bias_add(result_conv_1, self.b[9], name='add_bias'), name='relu_1')
            # conv_2
            self.w[10] = self.init_w(shape=[3, 3, 1024, 1024], name='w_10')
            self.b[10] = self.init_b(shape=[1024], name='b_10')
            result_conv_2 = tf.nn.conv2d(
                input=result_relu_1, filter=self.w[10],
                strides=[1, 1, 1, 1], padding='SAME', name='conv_2')
            result_relu_2 = tf.nn.relu(tf.nn.bias_add(result_conv_2, self.b[10], name='add_bias'), name='relu_2')
            # up sample: 2x2 transposed conv, doubles the side length and halves the channels (1024 -> 512)
            # output_shape is hard-coded for a 512x512 input (32x32 at the bottom -> 64x64)
            self.w[11] = self.init_w(shape=[2, 2, 512, 1024], name='w_11')
            self.b[11] = self.init_b(shape=[512], name='b_11')
            result_up = tf.nn.conv2d_transpose(
                value=result_relu_2, filter=self.w[11],
                output_shape=[batch_size, 64, 64, 512],
                strides=[1, 2, 2, 1], padding='VALID', name='Up_Sample')
            result_relu_3 = tf.nn.relu(tf.nn.bias_add(result_up, self.b[11], name='add_bias'), name='relu_3')
            # dropout
            result_dropout = tf.nn.dropout(x=result_relu_3, keep_prob=self.keep_prob)
        # layer 6
        with tf.name_scope('layer_6'):
            # copy, crop and merge
            result_merge = self.copy_and_crop_and_merge(
                result_from_contract_layer=self.result_from_contract_layer[4], result_from_upsampling=result_dropout)
            # print(result_merge)
            # conv_1 (input channels: 512 from the skip connection + 512 from up-sampling)
            self.w[12] = self.init_w(shape=[3, 3, 1024, 512], name='w_12')
            self.b[12] = self.init_b(shape=[512], name='b_12')
            result_conv_1 = tf.nn.conv2d(
                input=result_merge, filter=self.w[12],
                strides=[1, 1, 1, 1], padding='SAME', name='conv_1')
            result_relu_1 = tf.nn.relu(tf.nn.bias_add(result_conv_1, self.b[12], name='add_bias'), name='relu_1')
            # conv_2
            self.w[13] = self.init_w(shape=[3, 3, 512, 512], name='w_13')
            self.b[13] = self.init_b(shape=[512], name='b_13')
            result_conv_2 = tf.nn.conv2d(
                input=result_relu_1, filter=self.w[13],
                strides=[1, 1, 1, 1], padding='SAME', name='conv_2')
            result_relu_2 = tf.nn.relu(tf.nn.bias_add(result_conv_2, self.b[13], name='add_bias'), name='relu_2')
            # print(result_relu_2.shape[1])
            # up sample
            self.w[14] = self.init_w(shape=[2, 2, 256, 512], name='w_14')
            self.b[14] = self.init_b(shape=[256], name='b_14')
            result_up = tf.nn.conv2d_transpose(
                value=result_relu_2, filter=self.w[14],
                output_shape=[batch_size, 128, 128, 256],
                strides=[1, 2, 2, 1], padding='VALID', name='Up_Sample')
            result_relu_3 = tf.nn.relu(tf.nn.bias_add(result_up, self.b[14], name='add_bias'), name='relu_3')
            # dropout
            result_dropout = tf.nn.dropout(x=result_relu_3, keep_prob=self.keep_prob)
        # layer 7
        with tf.name_scope('layer_7'):
            # copy, crop and merge
            result_merge = self.copy_and_crop_and_merge(
                result_from_contract_layer=self.result_from_contract_layer[3], result_from_upsampling=result_dropout)
            # conv_1
            self.w[15] = self.init_w(shape=[3, 3, 512, 256], name='w_15')
            self.b[15] = self.init_b(shape=[256], name='b_15')
            result_conv_1 = tf.nn.conv2d(
                input=result_merge, filter=self.w[15],
                strides=[1, 1, 1, 1], padding='SAME', name='conv_1')
            result_relu_1 = tf.nn.relu(tf.nn.bias_add(result_conv_1, self.b[15], name='add_bias'), name='relu_1')
            # conv_2
            self.w[16] = self.init_w(shape=[3, 3, 256, 256], name='w_16')
            self.b[16] = self.init_b(shape=[256], name='b_16')
            result_conv_2 = tf.nn.conv2d(
                input=result_relu_1, filter=self.w[16],
                strides=[1, 1, 1, 1], padding='SAME', name='conv_2')
            result_relu_2 = tf.nn.relu(tf.nn.bias_add(result_conv_2, self.b[16], name='add_bias'), name='relu_2')
            # up sample
            self.w[17] = self.init_w(shape=[2, 2, 128, 256], name='w_17')
            self.b[17] = self.init_b(shape=[128], name='b_17')
            result_up = tf.nn.conv2d_transpose(
                value=result_relu_2, filter=self.w[17],
                output_shape=[batch_size, 256, 256, 128],
                strides=[1, 2, 2, 1], padding='VALID', name='Up_Sample')
            result_relu_3 = tf.nn.relu(tf.nn.bias_add(result_up, self.b[17], name='add_bias'), name='relu_3')
            # dropout
            result_dropout = tf.nn.dropout(x=result_relu_3, keep_prob=self.keep_prob)
        # layer 8
        with tf.name_scope('layer_8'):
            # copy, crop and merge
            result_merge = self.copy_and_crop_and_merge(
                result_from_contract_layer=self.result_from_contract_layer[2], result_from_upsampling=result_dropout)
            # conv_1
            self.w[18] = self.init_w(shape=[3, 3, 256, 128], name='w_18')
            self.b[18] = self.init_b(shape=[128], name='b_18')
            result_conv_1 = tf.nn.conv2d(
                input=result_merge, filter=self.w[18],
                strides=[1, 1, 1, 1], padding='SAME', name='conv_1')
            result_relu_1 = tf.nn.relu(tf.nn.bias_add(result_conv_1, self.b[18], name='add_bias'), name='relu_1')
            # conv_2
            self.w[19] = self.init_w(shape=[3, 3, 128, 128], name='w_19')
            self.b[19] = self.init_b(shape=[128], name='b_19')
            result_conv_2 = tf.nn.conv2d(
                input=result_relu_1, filter=self.w[19],
                strides=[1, 1, 1, 1], padding='SAME', name='conv_2')
            result_relu_2 = tf.nn.relu(tf.nn.bias_add(result_conv_2, self.b[19], name='add_bias'), name='relu_2')
            # up sample
            self.w[20] = self.init_w(shape=[2, 2, 64, 128], name='w_20')
            self.b[20] = self.init_b(shape=[64], name='b_20')
            result_up = tf.nn.conv2d_transpose(
                value=result_relu_2, filter=self.w[20],
                output_shape=[batch_size, 512, 512, 64],
                strides=[1, 2, 2, 1], padding='VALID', name='Up_Sample')
            result_relu_3 = tf.nn.relu(tf.nn.bias_add(result_up, self.b[20], name='add_bias'), name='relu_3')
            # dropout
            result_dropout = tf.nn.dropout(x=result_relu_3, keep_prob=self.keep_prob)
        # layer 9
        with tf.name_scope('layer_9'):
            # copy, crop and merge
            result_merge = self.copy_and_crop_and_merge(
                result_from_contract_layer=self.result_from_contract_layer[1], result_from_upsampling=result_dropout)
            # conv_1
            self.w[21] = self.init_w(shape=[3, 3, 128, 64], name='w_21')
            self.b[21] = self.init_b(shape=[64], name='b_21')
            result_conv_1 = tf.nn.conv2d(
                input=result_merge, filter=self.w[21],
                strides=[1, 1, 1, 1], padding='SAME', name='conv_1')
            result_relu_1 = tf.nn.relu(tf.nn.bias_add(result_conv_1, self.b[21], name='add_bias'), name='relu_1')
            # conv_2
            self.w[22] = self.init_w(shape=[3, 3, 64, 64], name='w_22')
            self.b[22] = self.init_b(shape=[64], name='b_22')
            result_conv_2 = tf.nn.conv2d(
                input=result_relu_1, filter=self.w[22],
                strides=[1, 1, 1, 1], padding='SAME', name='conv_2')
            result_relu_2 = tf.nn.relu(tf.nn.bias_add(result_conv_2, self.b[22], name='add_bias'), name='relu_2')
            # 1x1 convolution to [batch_size, OUTPUT_IMG_WIDE, OUTPUT_IMG_HEIGHT, CLASS_NUM]
            self.w[23] = self.init_w(shape=[1, 1, 64, CLASS_NUM], name='w_23')
            self.b[23] = self.init_b(shape=[CLASS_NUM], name='b_23')
            result_conv_3 = tf.nn.conv2d(
                input=result_relu_2, filter=self.w[23],
                strides=[1, 1, 1, 1], padding='VALID', name='conv_3')
            # self.prediction = tf.nn.relu(tf.nn.bias_add(result_conv_3, self.b[23], name='add_bias'), name='relu_3')
            # self.prediction = tf.nn.sigmoid(x=tf.nn.bias_add(result_conv_3, self.b[23], name='add_bias'), name='sigmoid_1')
            # raw logits: the softmax is applied inside the loss
            self.prediction = tf.nn.bias_add(result_conv_3, self.b[23], name='add_bias')
            # print(self.prediction)
            # print(self.input_label)
        # softmax loss
        with tf.name_scope('softmax_loss'):
            # using one-hot
            # self.loss = \
            #     tf.nn.softmax_cross_entropy_with_logits(labels=self.cast_label, logits=self.prediction, name='loss')
            # not using one-hot
            self.loss = \
                tf.nn.sparse_softmax_cross_entropy_with_logits(labels=self.input_label, logits=self.prediction, name='loss')
            self.loss_mean = tf.reduce_mean(self.loss)
            tf.add_to_collection(name='loss', value=self.loss_mean)
            # total loss = mean cross-entropy + the L2 penalties added in init_w
            self.loss_all = tf.add_n(inputs=tf.get_collection(key='loss'))
        # accuracy
        with tf.name_scope('accuracy'):
            # using one-hot
            # self.correct_prediction = tf.equal(tf.argmax(self.prediction, axis=3), tf.argmax(self.cast_label, axis=3))
            # not using one-hot
            self.correct_prediction = \
                tf.equal(tf.argmax(input=self.prediction, axis=3, output_type=tf.int32), self.input_label)
            self.correct_prediction = tf.cast(self.correct_prediction, tf.float32)
            self.accuracy = tf.reduce_mean(self.correct_prediction)
        # Gradient descent (Adam optimizer)
        with tf.name_scope('Gradient_Descent'):
            self.train_step = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(self.loss_all)

    def train(self):
        # import cv2
        # import numpy as np
        # ckpt_path = os.path.join(FLAGS.model_dir, "model.ckpt")
        # all_parameters_saver = tf.train.Saver()
        # # import numpy as np
        # # mydata = DataProcess(INPUT_IMG_HEIGHT, INPUT_IMG_WIDE)
        # # imgs_train, imgs_mask_train = mydata.load_my_train_data()
        # my_set_image = cv2.imread('../data_set/train.tif', flags=0)
        # my_set_label = cv2.imread('../data_set/label.tif', flags=0)
        # my_set_image.astype('float32')
        # my_set_label[my_set_label <= 128] = 0
        # my_set_label[my_set_label > 128] = 1
        # my_set_image = np.reshape(a=my_set_image, newshape=(1, INPUT_IMG_WIDE, INPUT_IMG_HEIGHT, INPUT_IMG_CHANNEL))
        # my_set_label = np.reshape(a=my_set_label, newshape=(1, OUTPUT_IMG_WIDE, OUTPUT_IMG_HEIGHT))
        # # cv2.imshow('image', my_set_image)
        # # cv2.imshow('label', my_set_label * 100)
        # # cv2.waitKey(0)
        # with tf.Session() as sess:  # start a session
        #     sess.run(tf.global_variables_initializer())
        #     sess.run(tf.local_variables_initializer())
        #     for epoch in range(10):
        #         lo, acc = sess.run(
        #             [self.loss_mean, self.accuracy],
        #             feed_dict={
        #                 self.input_image: my_set_image, self.input_label: my_set_label,
        #                 self.keep_prob: 1.0, self.lamb: 0.004}
        #         )
        #         # print('num %d, loss: %.6f and accuracy: %.6f' % (epoch, lo, acc))
        #         print('num %d, loss: %.6f and accuracy: %.6f' % (epoch, lo, acc))
        #         sess.run(
        #             [self.train_step],
        #             feed_dict={
        #                 self.input_image: my_set_image, self.input_label: my_set_label,
        #                 self.keep_prob: 0.6, self.lamb: 0.004}
        #         )
        #     all_parameters_saver.save(sess=sess, save_path=ckpt_path)
        # print("Done training")
        train_file_path = os.path.join(FLAGS.data_dir, TRAIN_SET_NAME)
        train_image_filename_queue = tf.train.string_input_producer(
            string_tensor=tf.train.match_filenames_once(train_file_path), num_epochs=EPOCH_NUM, shuffle=True)
        ckpt_path = CHECK_POINT_PATH
        train_images, train_labels = read_image_batch(train_image_filename_queue, TRAIN_BATCH_SIZE)
        tf.summary.scalar("loss", self.loss_mean)
        tf.summary.scalar('accuracy', self.accuracy)
        merged_summary = tf.summary.merge_all()
        all_parameters_saver = tf.train.Saver()
        with tf.Session() as sess:  # start a session
            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())
            summary_writer = tf.summary.FileWriter(FLAGS.tb_dir, sess.graph)
            tf.summary.FileWriter(FLAGS.model_dir, sess.graph)
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)
            try:
                epoch = 1  # counts training steps (batches), not full epochs
                while not coord.should_stop():
                    # Run training steps or whatever
                    # print('epoch ' + str(epoch))
                    example, label = sess.run([train_images, train_labels])  # fetch an image batch and its labels
                    # print(label)
                    lo, acc, summary_str = sess.run(
                        [self.loss_mean, self.accuracy, merged_summary],
                        feed_dict={
                            self.input_image: example, self.input_label: label, self.keep_prob: 1.0,
                            self.lamb: 0.004}
                    )
                    summary_writer.add_summary(summary_str, epoch)
                    # print('num %d, loss: %.6f and accuracy: %.6f' % (epoch, lo, acc))
                    if epoch % 10 == 0:
                        print('num %d, loss: %.6f and accuracy: %.6f' % (epoch, lo, acc))
                    sess.run(
                        [self.train_step],
                        feed_dict={
                            self.input_image: example, self.input_label: label, self.keep_prob: 0.6,
                            self.lamb: 0.004}
                    )
                    epoch += 1
            except tf.errors.OutOfRangeError:
                print('Done training -- epoch limit reached')
            finally:
                # When done, ask the threads to stop.
                all_parameters_saver.save(sess=sess, save_path=ckpt_path)
                coord.request_stop()
            # coord.request_stop()
            coord.join(threads)
        print("Done training")

    def validate(self):
        # import cv2
        # import numpy as np
        # ckpt_path = os.path.join(FLAGS.model_dir, "model.ckpt")
        # # mydata = DataProcess(INPUT_IMG_HEIGHT, INPUT_IMG_WIDE)
        # # imgs_train, imgs_mask_train = mydata.load_my_train_data()
        # all_parameters_saver = tf.train.Saver()
        # my_set_image = cv2.imread('../data_set/train.tif', flags=0)
        # my_set_label = cv2.imread('../data_set/label.tif', flags=0)
        # my_set_image.astype('float32')
        # my_set_label[my_set_label <= 128] = 0
        # my_set_label[my_set_label > 128] = 1
        # with tf.Session() as sess:
        #     sess.run(tf.global_variables_initializer())
        #     sess.run(tf.local_variables_initializer())
        #     all_parameters_saver.restore(sess=sess, save_path=ckpt_path)
        #     image, acc = sess.run(
        #         fetches=[self.prediction, self.accuracy],
        #         feed_dict={
        #             self.input_image: my_set_image, self.input_label: my_set_label,
        #             self.keep_prob: 1.0, self.lamb: 0.004}
        #     )
        #     image = np.argmax(a=image[0], axis=2).astype('uint8') * 255
        #     # cv2.imshow('predict', image)
        #     # cv2.imshow('o', np.asarray(a=image[0], dtype=np.uint8) * 100)
        #     # cv2.waitKey(0)
        #     cv2.imwrite(filename=os.path.join(FLAGS.model_dir, 'predict.jpg'), img=image)
        #     print(acc)
        # print("Done test, predict image has been saved to %s" % (os.path.join(FLAGS.model_dir, 'predict.jpg')))
        validation_file_path = os.path.join(FLAGS.data_dir, VALIDATION_SET_NAME)
        validation_image_filename_queue = tf.train.string_input_producer(
            string_tensor=tf.train.match_filenames_once(validation_file_path), num_epochs=1, shuffle=True)
        ckpt_path = CHECK_POINT_PATH
        validation_images, validation_labels = read_image_batch(validation_image_filename_queue, VALIDATION_BATCH_SIZE)
        # tf.summary.scalar("loss", self.loss_mean)
        # tf.summary.scalar('accuracy', self.accuracy)
        # merged_summary = tf.summary.merge_all()
        all_parameters_saver = tf.train.Saver()
        with tf.Session() as sess:  # start a session
            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())
            # summary_writer = tf.summary.FileWriter(FLAGS.tb_dir, sess.graph)
            # tf.summary.FileWriter(FLAGS.model_dir, sess.graph)
            all_parameters_saver.restore(sess=sess, save_path=ckpt_path)
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)
            try:
                epoch = 1
                while not coord.should_stop():
                    # Run training steps or whatever
                    # print('epoch ' + str(epoch))
                    example, label = sess.run([validation_images, validation_labels])  # fetch an image batch and its labels
                    # print(label)
                    lo, acc = sess.run(
                        [self.loss_mean, self.accuracy],
                        feed_dict={
                            self.input_image: example, self.input_label: label, self.keep_prob: 1.0,
                            self.lamb: 0.004}
                    )
                    # summary_writer.add_summary(summary_str, epoch)
                    # print('num %d, loss: %.6f and accuracy: %.6f' % (epoch, lo, acc))
                    if epoch % 1 == 0:
                        print('num %d, loss: %.6f and accuracy: %.6f' % (epoch, lo, acc))
                    epoch += 1
            except tf.errors.OutOfRangeError:
                print('Done validating -- epoch limit reached')
            finally:
                # When done, ask the threads to stop.
                coord.request_stop()
            # coord.request_stop()
            coord.join(threads)
        print('Done validating')

    def test(self):
        import cv2
        test_file_path = os.path.join(FLAGS.data_dir, TEST_SET_NAME)
        test_image_filename_queue = tf.train.string_input_producer(
            string_tensor=tf.train.match_filenames_once(test_file_path), num_epochs=1, shuffle=True)
        ckpt_path = CHECK_POINT_PATH
        test_images, test_labels = read_image_batch(test_image_filename_queue, TEST_BATCH_SIZE)
        # build the argmax op once, instead of adding a new op to the graph on every loop iteration
        predicted_classes = tf.argmax(input=self.prediction, axis=3)
        # tf.summary.scalar("loss", self.loss_mean)
        # tf.summary.scalar('accuracy', self.accuracy)
        # merged_summary = tf.summary.merge_all()
        all_parameters_saver = tf.train.Saver()
        with tf.Session() as sess:  # start a session
            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())
            # summary_writer = tf.summary.FileWriter(FLAGS.tb_dir, sess.graph)
            # tf.summary.FileWriter(FLAGS.model_dir, sess.graph)
            all_parameters_saver.restore(sess=sess, save_path=ckpt_path)
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)
            sum_acc = 0.0
            try:
                epoch = 0
                while not coord.should_stop():
                    # Run training steps or whatever
                    # print('epoch ' + str(epoch))
                    example, label = sess.run([test_images, test_labels])  # fetch an image batch and its labels
                    # print(label)
                    image, acc = sess.run(
                        [predicted_classes, self.accuracy],
                        feed_dict={
                            self.input_image: example, self.input_label: label,
                            self.keep_prob: 1.0, self.lamb: 0.004
                        }
                    )
                    sum_acc += acc
                    epoch += 1
                    # argmax returns int64; cast to uint8 before writing an 8-bit image
                    cv2.imwrite(os.path.join(PREDICT_SAVED_DIRECTORY, '%d.jpg' % epoch), (image[0] * 255).astype('uint8'))
                    if epoch % 1 == 0:
                        print('num %d accuracy: %.6f' % (epoch, acc))
            except tf.errors.OutOfRangeError:
                # average of the per-batch accuracies over the batches actually evaluated
                print('Done testing -- epoch limit reached \n Average accuracy: %.2f%%' % (sum_acc / max(epoch, 1) * 100))
            finally:
                # When done, ask the threads to stop.
                coord.request_stop()
            # coord.request_stop()
            coord.join(threads)
        print('Done testing')

    def predict(self):
        import cv2
        import glob
        import numpy as np
        # TODO: don't do it this way; predict by reading the images directly instead of from the tfrecord,
        # because the order changes and the outputs can no longer be matched to the inputs
        predict_file_path = glob.glob(os.path.join(ORIGIN_PREDICT_DIRECTORY, '*.tif'))
        print(len(predict_file_path))
        if not os.path.lexists(PREDICT_SAVED_DIRECTORY):
            os.mkdir(PREDICT_SAVED_DIRECTORY)
        ckpt_path = CHECK_POINT_PATH
        # build the argmax op once, outside the loop
        predicted_classes = tf.argmax(input=self.prediction, axis=3)
        all_parameters_saver = tf.train.Saver()
        with tf.Session() as sess:  # start a session
            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())
            # summary_writer = tf.summary.FileWriter(FLAGS.tb_dir, sess.graph)
            # tf.summary.FileWriter(FLAGS.model_dir, sess.graph)
            all_parameters_saver.restore(sess=sess, save_path=ckpt_path)
            for index, image_path in enumerate(predict_file_path):
                # image = cv2.imread(image_path, flags=0)
                # read as grayscale and add batch/channel axes (requires set_up_unet(batch_size=1))
                image = np.reshape(a=cv2.imread(image_path, flags=0), newshape=(1, INPUT_IMG_WIDE, INPUT_IMG_HEIGHT, INPUT_IMG_CHANNEL))
                predict_image = sess.run(
                    predicted_classes,
                    feed_dict={
                        self.input_image: image,
                        self.keep_prob: 1.0, self.lamb: 0.004
                    }
                )
                # argmax returns int64; cast to uint8 before writing an 8-bit image
                cv2.imwrite(os.path.join(PREDICT_SAVED_DIRECTORY, '%d.jpg' % index), (predict_image[0] * 255).astype('uint8'))
        print('Done prediction')
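The `test()` and `predict()` methods turn logits into an image with an argmax over the class axis and write it with `cv2.imwrite`. The NumPy sketch below reproduces that post-processing and the pixel accuracy computed by `self.accuracy`, which is handy for checking saved outputs offline. `logits_to_mask` and `pixel_accuracy` are hypothetical helpers, and the ×255 scaling assumes two classes, as in the code above.

```python
import numpy as np


def logits_to_mask(logits):
    """Argmax over the class axis, scaled to 0/255 and cast to uint8 for saving.

    logits: (batch, H, W, CLASS_NUM), like self.prediction; assumes two classes.
    """
    classes = np.argmax(logits, axis=3)      # int64, like tf.argmax's default output
    return (classes * 255).astype(np.uint8)  # JPEG stores 8-bit values, so cast explicitly


def pixel_accuracy(logits, labels):
    """Fraction of pixels whose argmax class equals the label (mirrors self.accuracy)."""
    return float(np.mean(np.argmax(logits, axis=3) == labels))


logits = np.zeros((1, 2, 2, 2))
logits[0, 0, 0, 1] = 1.0  # predict class 1 at (0, 0)
logits[0, 1, 1, 1] = 1.0  # predict class 1 at (1, 1)
labels = np.array([[[1, 0], [0, 0]]])
mask = logits_to_mask(logits)
acc = pixel_accuracy(logits, labels)
```

Here `mask[0]` is `[[255, 0], [0, 255]]`, and three of the four pixels match their labels, so `acc` is 0.75.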