U-Net Structure and Annotated Code

I previously read through U-Net code without actually running it, and from the related blog posts I also collected some common beginner questions about U-Net:

1. The typical U-Net architecture, and the architecture in the paper

2. U-Net's data-augmentation strategy

3. How U-Net is implemented in code

4. U-Net's loss function

If you are in the habit of reading papers, the first things to look at are what the paper is applied to and what it improves over prior work.

This blog post explains what U-Net does and what its characteristics are: https://blog.csdn.net/u012931582/article/details/70215756

U-Net belongs to the FCN (fully convolutional network) family: input and output are both images, and there are no fully connected layers. The shallower, high-resolution layers handle pixel localization, while the deeper layers handle pixel classification. It is end-to-end learning; image style transfer and image super-resolution use the same kind of framework.

 

Having roughly explained why the U-Net structure matters, let's go through the questions listed at the beginning:

1. The typical U-Net architecture:

Anyone who has looked at the paper online will have seen the U-shaped architecture diagram:

First layer:

The input is a 572*572*1 image, while the original input tile is actually 512*512*1. During the 3*3 convolutions the feature map keeps shrinking, because the paper uses 'VALID' convolutions rather than 'SAME'. To segment border pixels accurately as well, U-Net uses a mirroring operation (the overlap-tile strategy): a mirrored (reflected) border is added around the input image. How wide should that border be? A good strategy is to determine it from the receptive field. Valid convolutions reduce the feature-map resolution, but we want the border pixels of the 512×512 image to survive to the last feature map, so we enlarge the input by padding. The total added size equals the receptive field of the contracting path, i.e. half of the receptive field is mirrored onto each side.

The figure is taken from Zhihu.
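As a rough illustration of the mirroring (my own sketch, not code from the original repo), reflect-padding a 512×512 tile by half the receptive field yields exactly the 572×572 input; the 30-pixel border follows from equation (1) below:

import numpy as np

img = np.random.rand(512, 512).astype(np.float32)   # stand-in for one 512x512 input tile
pad = 30                                             # half of the 60-pixel receptive field from equation (1)
padded = np.pad(img, pad_width=pad, mode='reflect')  # mirror the border on every side
print(padded.shape)                                  # (572, 572)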

From the contracting-path architecture shown in the figure, we can compute its receptive field:

\text{rf} = (((0 \times 2 + 2 + 2) \times 2 + 2 + 2) \times 2 + 2 + 2) \times 2 + 2 + 2 = 60 \tag{1}

That is why the U-Net input is 572×572. Another benefit of 572 is that the feature map fed into every downsampling step has an even size; this value is closely tied to the network structure. Related post: https://zhuanlan.zhihu.com/p/43927696
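Equation (1) can be checked with a few lines of Python (a sketch of the same recursion: every 2×2 pooling doubles the amount accumulated so far, and every 3×3 valid convolution adds 2):

def contracting_path_rf(num_blocks=4, convs_per_block=2, kernel=3, pool=2):
	# each 2x2 max-pool doubles the receptive field accumulated so far,
	# and each 3x3 VALID convolution adds (kernel - 1) = 2 to it
	rf = 0
	for _ in range(num_blocks):
		rf = rf * pool + (kernel - 1) * convs_per_block
	return rf

print(contracting_path_rf())  # 60, hence the 512 + 60 = 572 input size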

The left half (contracting path): this is an ordinary CNN structure, where one unit consists of two 3×3 conv + ReLU steps: the first conv + ReLU doubles the number of channels, and the second keeps it unchanged.

The bottom layer: same as before, two convolutions, the first doubling the channels and the second keeping them unchanged.

The right half (expansive path):

The blue boxes are produced by up-convolution (a 2×2 transposed convolution, often loosely called deconvolution; see the papers on visualizing network features). Its effect is to double the side length while halving the number of channels. The white boxes and grey arrows represent concatenating the earlier contracting-path features with the current features (a concat operation, similar to what DenseNet does). The reason is that repeated downsampling makes the information more and more abstract but also loses detail; combining features from earlier layers gives a better basis for the segmentation decision. For example, concatenating near-original-resolution features with the last layer: upsampling alone may only recover a rough region (the features were downsampled to a very small size, after all), while the original-resolution features make it possible to localize the segmentation precisely.
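A minimal sketch of one expansive-path step, using the same TF 1.x calls as the code below (the tensor shapes are illustrative, assuming a 512×512 input with 'SAME' convolutions):

import tensorflow as tf  # TF 1.x API

# hypothetical shapes: `deep` comes from the bottom layer, `skip` was saved on the contracting path
deep = tf.placeholder(tf.float32, shape=[1, 32, 32, 1024])
skip = tf.placeholder(tf.float32, shape=[1, 64, 64, 512])
w_up = tf.Variable(tf.truncated_normal(shape=[2, 2, 512, 1024], stddev=0.1))  # [k, k, out_ch, in_ch]
up = tf.nn.conv2d_transpose(value=deep, filter=w_up, output_shape=[1, 64, 64, 512],
	strides=[1, 2, 2, 1], padding='VALID')           # side length doubles, channels halve
merged = tf.concat(values=[skip, up], axis=-1)       # channels: 512 + 512 = 1024
print(merged.shape)                                  # (1, 64, 64, 1024)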

The difference between the input and output sizes comes from using 'VALID' convolutions. With 'SAME' padding the input and output can have the same size, and the code at the end of this post does exactly that. I have not run experiments comparing the two, so I cannot say how the same-size variant performs, but most implementations nowadays are set up with equal input and output sizes.
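A quick shape check of the two padding modes (a standalone snippet, not part of the network code):

import tensorflow as tf  # TF 1.x API

x = tf.placeholder(tf.float32, shape=[1, 572, 572, 1])
w = tf.Variable(tf.truncated_normal(shape=[3, 3, 1, 64], stddev=0.1))
print(tf.nn.conv2d(input=x, filter=w, strides=[1, 1, 1, 1], padding='VALID').shape)  # (1, 570, 570, 64)
print(tf.nn.conv2d(input=x, filter=w, strides=[1, 1, 1, 1], padding='SAME').shape)   # (1, 572, 572, 64)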

 

2. Data augmentation: because of the special nature of biomedical images, deformed tissue still looks like plausible tissue (the example figures are omitted here).


 

Paper on image warping (moving least squares): http://faculty.cs.tamu.edu/schaefer/research/mls.pdf

Noise can of course be added as well.
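As a hedged sketch of the elastic-deformation idea (alpha and sigma are illustrative assumptions, not the paper's values), a smoothed random displacement field can be generated with scipy and applied to the image; Gaussian noise could be added inside the same function. The same displacement field must also be applied to the label mask.

import numpy as np
from scipy.ndimage import gaussian_filter, map_coordinates

def elastic_deform(image, alpha=34.0, sigma=4.0, seed=None):
	# random displacement field, smoothed with a Gaussian and scaled by alpha
	rng = np.random.RandomState(seed)
	dx = gaussian_filter(rng.uniform(-1, 1, image.shape), sigma) * alpha
	dy = gaussian_filter(rng.uniform(-1, 1, image.shape), sigma) * alpha
	rows, cols = np.meshgrid(np.arange(image.shape[0]), np.arange(image.shape[1]), indexing='ij')
	coords = np.array([rows + dy, cols + dx])
	return map_coordinates(image, coords, order=1, mode='reflect')

# use the same seed (i.e. the same displacement field) for an image and its label mask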

There is another detail worth noting (see https://zhuanlan.zhihu.com/p/43927696): U-Net's loss function. Sometimes the images to be segmented contain cells that touch each other tightly, so the boundaries between them are very hard to detect, and the loss function has to be designed accordingly.

So how should the loss be designed so that the model learns to separate touching objects? U-Net uses a cross-entropy loss with boundary-aware pixel weights:

E = \sum_{\mathbf{x} \in \Omega} w(\mathbf{x}) \log\left(p_{\ell(\mathbf{x})}(\mathbf{x})\right)

where p_{\ell(\mathbf{x})}(\mathbf{x}) is the softmax probability assigned to the true class of pixel \mathbf{x}, \ell: \Omega \rightarrow \{1,...,K\} is the pixel's label, and w: \Omega \rightarrow \mathbb{R} is a per-pixel weight map whose purpose is to give pixels close to cell boundaries a higher weight.

The weight map is

w(\mathbf{x}) = w_c(\mathbf{x}) + w_0 \cdot \exp\left(-\frac{(d_1(\mathbf{x}) + d_2(\mathbf{x}))^2}{2\sigma^2}\right)

where w_c: \Omega \rightarrow \mathbb{R} balances the class frequencies, d_1: \Omega \rightarrow \mathbb{R} is the distance from a pixel to the nearest cell, and d_2: \Omega \rightarrow \mathbb{R} is the distance to the second-nearest cell. w_0 and \sigma are constants; the experiments use w_0 = 10 and \sigma \approx 5.
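A sketch of how such a weight map can be computed with scipy (a commonly used approximation rather than the authors' exact implementation; here the border term is only applied to background pixels):

import numpy as np
from scipy.ndimage import distance_transform_edt, label

def unet_weight_map(mask, w0=10.0, sigma=5.0, class_weights=(1.0, 1.0)):
	# mask: binary ground truth, 1 = cell, 0 = background
	labeled, n = label(mask)
	w_c = np.where(mask > 0, class_weights[1], class_weights[0]).astype(np.float64)
	if n < 2:
		return w_c  # fewer than two cells: no touching borders to emphasise
	# distance from every pixel to each individual cell, then keep the two smallest
	dists = np.stack([distance_transform_edt(labeled != i) for i in range(1, n + 1)])
	dists.sort(axis=0)
	d1, d2 = dists[0], dists[1]
	border = w0 * np.exp(-((d1 + d2) ** 2) / (2.0 * sigma ** 2))
	return w_c + (mask == 0) * border  # border emphasis on background pixels between cells

To actually use it, the per-pixel cross-entropy would be multiplied by this map before averaging; note that the TensorFlow code below does not implement this weighting.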

Of course, for some images this weighting is unnecessary, e.g. the TestResult image shown in the original post.

 

Code walkthrough (with some comments):
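The code targets the TF 1.x API and relies on several module-level constants and helpers (read_image_batch, FLAGS, the dataset/checkpoint path constants) defined elsewhere in the repository and not shown here. For orientation only, a plausible (assumed, not the author's) set of the main numeric constants would look like this; the path constants are omitted:

import os
import tensorflow as tf  # TF 1.x

INPUT_IMG_WIDE, INPUT_IMG_HEIGHT, INPUT_IMG_CHANNEL = 512, 512, 1   # assumed tile size
OUTPUT_IMG_WIDE, OUTPUT_IMG_HEIGHT = 512, 512
CLASS_NUM = 2                                                       # background / foreground
TRAIN_BATCH_SIZE = VALIDATION_BATCH_SIZE = TEST_BATCH_SIZE = 1      # assumed batch sizes
EPOCH_NUM = 10                                                      # assumed number of epochs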


class Unet:

	def __init__(self):
		print('New U-net Network')
		self.input_image = None
		self.input_label = None
		self.cast_image = None
		self.cast_label = None
		self.keep_prob = None
		self.lamb = None
		self.result_expand = None
		self.loss, self.loss_mean, self.loss_all, self.train_step = [None] * 4
		self.prediction, self.correct_prediction, self.accuracy = [None] * 3
		self.result_conv = {}
		self.result_relu = {}
		self.result_maxpool = {}
		self.result_from_contract_layer = {}
		self.w = {}
		self.b = {}

	def init_w(self, shape, name):
		with tf.name_scope('init_w'):
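			# He initialization: stddev = sqrt(2 / fan_in), with fan_in = k * k * in_channels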
			stddev = tf.sqrt(x=2.0 / (shape[0] * shape[1] * shape[2]))
			# stddev = 0.01
			w = tf.Variable(initial_value=tf.truncated_normal(shape=shape, stddev=stddev, dtype=tf.float32), name=name)
			tf.add_to_collection(name='loss', value=tf.contrib.layers.l2_regularizer(self.lamb)(w))
			return w

	@staticmethod
	def init_b(shape, name):
		with tf.name_scope('init_b'):
			return tf.Variable(initial_value=tf.random_normal(shape=shape, dtype=tf.float32), name=name)

	@staticmethod
	def copy_and_crop_and_merge(result_from_contract_layer, result_from_upsampling):
		# result_from_contract_layer_shape = tf.shape(result_from_contract_layer)
		# result_from_upsampling_shape = tf.shape(result_from_upsampling)
		# result_from_contract_layer_crop = \
		# 	tf.slice(
		# 		input_=result_from_contract_layer,
		# 		begin=[
		# 			0,
		# 			(result_from_contract_layer_shape[1] - result_from_upsampling_shape[1]) // 2,
		# 			(result_from_contract_layer_shape[2] - result_from_upsampling_shape[2]) // 2,
		# 			0
		# 		],
		# 		size=[
		# 			result_from_upsampling_shape[0],
		# 			result_from_upsampling_shape[1],
		# 			result_from_upsampling_shape[2],
		# 			result_from_upsampling_shape[3]
		# 		]
		# 	)
		result_from_contract_layer_crop = result_from_contract_layer
		return tf.concat(values=[result_from_contract_layer_crop, result_from_upsampling], axis=-1)

	def set_up_unet(self, batch_size):
		# input
		with tf.name_scope('input'):
			# learning_rate = tf.train.exponential_decay()
			self.input_image = tf.placeholder(
				dtype=tf.float32, shape=[batch_size, INPUT_IMG_WIDE, INPUT_IMG_WIDE, INPUT_IMG_CHANNEL], name='input_images'
			)
			# self.cast_image = tf.reshape(
			# 	tensor=self.input_image,
			# 	shape=[batch_size, INPUT_IMG_WIDE, INPUT_IMG_WIDE, INPUT_IMG_CHANNEL]
			# )

			# for softmax_cross_entropy_with_logits(labels=self.input_label, logits=self.prediction, name='loss')
			# using one-hot
			# self.input_label = tf.placeholder(
			# 	dtype=tf.uint8, shape=[OUTPUT_IMG_WIDE, OUTPUT_IMG_WIDE], name='input_labels'
			# )
			# self.cast_label = tf.reshape(
			# 	tensor=self.input_label,
			# 	shape=[batch_size, OUTPUT_IMG_WIDE, OUTPUT_IMG_HEIGHT]
			# )

			# for sparse_softmax_cross_entropy_with_logits(labels=self.input_label, logits=self.prediction, name='loss')
			# not using one-hot coding
			self.input_label = tf.placeholder(
				dtype=tf.int32, shape=[batch_size, OUTPUT_IMG_WIDE, OUTPUT_IMG_WIDE], name='input_labels'
			)
			self.keep_prob = tf.placeholder(dtype=tf.float32, name='keep_prob')
			self.lamb = tf.placeholder(dtype=tf.float32, name='lambda')

		# layer 1
		with tf.name_scope('layer_1'):
			# conv_1
			self.w[1] = self.init_w(shape=[3, 3, INPUT_IMG_CHANNEL, 64], name='w_1')
			self.b[1] = self.init_b(shape=[64], name='b_1')
			result_conv_1 = tf.nn.conv2d(
				input=self.input_image, filter=self.w[1],
				strides=[1, 1, 1, 1], padding='SAME', name='conv_1')
			result_relu_1 = tf.nn.relu(tf.nn.bias_add(result_conv_1, self.b[1], name='add_bias'), name='relu_1')

			# conv_2
			self.w[2] = self.init_w(shape=[3, 3, 64, 64], name='w_2')
			self.b[2] = self.init_b(shape=[64], name='b_2')
			result_conv_2 = tf.nn.conv2d(
				input=result_relu_1, filter=self.w[2],
				strides=[1, 1, 1, 1], padding='SAME', name='conv_2')
			result_relu_2 = tf.nn.relu(tf.nn.bias_add(result_conv_2, self.b[2], name='add_bias'), name='relu_2')
			self.result_from_contract_layer[1] = result_relu_2  # save this layer's output for the expansive path

			# maxpool
			result_maxpool = tf.nn.max_pool(
				value=result_relu_2, ksize=[1, 2, 2, 1],
				strides=[1, 2, 2, 1], padding='VALID', name='maxpool')

			# dropout
			result_dropout = tf.nn.dropout(x=result_maxpool, keep_prob=self.keep_prob)
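			# note: the paper only applies drop-out at the end of the contracting path; here it follows every pooling step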

		# layer 2
		with tf.name_scope('layer_2'):
			# conv_1
			self.w[3] = self.init_w(shape=[3, 3, 64, 128], name='w_3')
			self.b[3] = self.init_b(shape=[128], name='b_3')
			result_conv_1 = tf.nn.conv2d(
				input=result_dropout, filter=self.w[3],
				strides=[1, 1, 1, 1], padding='SAME', name='conv_1')
			result_relu_1 = tf.nn.relu(tf.nn.bias_add(result_conv_1, self.b[3], name='add_bias'), name='relu_1')

			# conv_2
			self.w[4] = self.init_w(shape=[3, 3, 128, 128], name='w_4')
			self.b[4] = self.init_b(shape=[128], name='b_4')
			result_conv_2 = tf.nn.conv2d(
				input=result_relu_1, filter=self.w[4],
				strides=[1, 1, 1, 1], padding='SAME', name='conv_2')
			result_relu_2 = tf.nn.relu(tf.nn.bias_add(result_conv_2, self.b[4], name='add_bias'), name='relu_2')
			self.result_from_contract_layer[2] = result_relu_2  # save this layer's output for the expansive path

			# maxpool
			result_maxpool = tf.nn.max_pool(
				value=result_relu_2, ksize=[1, 2, 2, 1],
				strides=[1, 2, 2, 1], padding='VALID', name='maxpool')

			# dropout
			result_dropout = tf.nn.dropout(x=result_maxpool, keep_prob=self.keep_prob)

		# layer 3
		with tf.name_scope('layer_3'):
			# conv_1
			self.w[5] = self.init_w(shape=[3, 3, 128, 256], name='w_5')
			self.b[5] = self.init_b(shape=[256], name='b_5')
			result_conv_1 = tf.nn.conv2d(
				input=result_dropout, filter=self.w[5],
				strides=[1, 1, 1, 1], padding='SAME', name='conv_1')
			result_relu_1 = tf.nn.relu(tf.nn.bias_add(result_conv_1, self.b[5], name='add_bias'), name='relu_1')

			# conv_2
			self.w[6] = self.init_w(shape=[3, 3, 256, 256], name='w_6')
			self.b[6] = self.init_b(shape=[256], name='b_6')
			result_conv_2 = tf.nn.conv2d(
				input=result_relu_1, filter=self.w[6],
				strides=[1, 1, 1, 1], padding='SAME', name='conv_2')
			result_relu_2 = tf.nn.relu(tf.nn.bias_add(result_conv_2, self.b[6], name='add_bias'), name='relu_2')
			self.result_from_contract_layer[3] = result_relu_2  # save this layer's output for the expansive path

			# maxpool
			result_maxpool = tf.nn.max_pool(
				value=result_relu_2, ksize=[1, 2, 2, 1],
				strides=[1, 2, 2, 1], padding='VALID', name='maxpool')

			# dropout
			result_dropout = tf.nn.dropout(x=result_maxpool, keep_prob=self.keep_prob)

		# layer 4
		with tf.name_scope('layer_4'):
			# conv_1
			self.w[7] = self.init_w(shape=[3, 3, 256, 512], name='w_7')
			self.b[7] = self.init_b(shape=[512], name='b_7')
			result_conv_1 = tf.nn.conv2d(
				input=result_dropout, filter=self.w[7],
				strides=[1, 1, 1, 1], padding='SAME', name='conv_1')
			result_relu_1 = tf.nn.relu(tf.nn.bias_add(result_conv_1, self.b[7], name='add_bias'), name='relu_1')

			# conv_2
			self.w[8] = self.init_w(shape=[3, 3, 512, 512], name='w_8')
			self.b[8] = self.init_b(shape=[512], name='b_8')
			result_conv_2 = tf.nn.conv2d(
				input=result_relu_1, filter=self.w[8],
				strides=[1, 1, 1, 1], padding='SAME', name='conv_2')
			result_relu_2 = tf.nn.relu(tf.nn.bias_add(result_conv_2, self.b[8], name='add_bias'), name='relu_2')
			self.result_from_contract_layer[4] = result_relu_2  # save this layer's output for the expansive path

			# maxpool
			result_maxpool = tf.nn.max_pool(
				value=result_relu_2, ksize=[1, 2, 2, 1],
				strides=[1, 2, 2, 1], padding='VALID', name='maxpool')

			# dropout
			result_dropout = tf.nn.dropout(x=result_maxpool, keep_prob=self.keep_prob)

		# layer 5 (bottom)
		with tf.name_scope('layer_5'):
			# conv_1
			self.w[9] = self.init_w(shape=[3, 3, 512, 1024], name='w_9')
			self.b[9] = self.init_b(shape=[1024], name='b_9')
			result_conv_1 = tf.nn.conv2d(
				input=result_dropout, filter=self.w[9],
				strides=[1, 1, 1, 1], padding='SAME', name='conv_1')
			result_relu_1 = tf.nn.relu(tf.nn.bias_add(result_conv_1, self.b[9], name='add_bias'), name='relu_1')

			# conv_2
			self.w[10] = self.init_w(shape=[3, 3, 1024, 1024], name='w_10')
			self.b[10] = self.init_b(shape=[1024], name='b_10')
			result_conv_2 = tf.nn.conv2d(
				input=result_relu_1, filter=self.w[10],
				strides=[1, 1, 1, 1], padding='SAME', name='conv_2')
			result_relu_2 = tf.nn.relu(tf.nn.bias_add(result_conv_2, self.b[10], name='add_bias'), name='relu_2')

			# up sample
			self.w[11] = self.init_w(shape=[2, 2, 512, 1024], name='w_11')
			self.b[11] = self.init_b(shape=[512], name='b_11')
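			# NOTE: output_shape is hard-coded for a 512x512 input; with 'SAME' convolutions and four 2x2
			# max-pools the bottom feature map is 32x32, so this transposed conv doubles it back to 64x64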
			result_up = tf.nn.conv2d_transpose(
				value=result_relu_2, filter=self.w[11],
				output_shape=[batch_size, 64, 64, 512],
				strides=[1, 2, 2, 1], padding='VALID', name='Up_Sample')
			result_relu_3 = tf.nn.relu(tf.nn.bias_add(result_up, self.b[11], name='add_bias'), name='relu_3')

			# dropout
			result_dropout = tf.nn.dropout(x=result_relu_3, keep_prob=self.keep_prob)

		# layer 6
		with tf.name_scope('layer_6'):
			# copy, crop and merge
			result_merge = self.copy_and_crop_and_merge(
				result_from_contract_layer=self.result_from_contract_layer[4], result_from_upsampling=result_dropout)
			# print(result_merge)

			# conv_1
			self.w[12] = self.init_w(shape=[3, 3, 1024, 512], name='w_12')
			self.b[12] = self.init_b(shape=[512], name='b_12')
			result_conv_1 = tf.nn.conv2d(
				input=result_merge, filter=self.w[12],
				strides=[1, 1, 1, 1], padding='SAME', name='conv_1')
			result_relu_1 = tf.nn.relu(tf.nn.bias_add(result_conv_1, self.b[12], name='add_bias'), name='relu_1')

			# conv_2
			self.w[13] = self.init_w(shape=[3, 3, 512, 512], name='w_13')
			self.b[13] = self.init_b(shape=[512], name='b_13')
			result_conv_2 = tf.nn.conv2d(
				input=result_relu_1, filter=self.w[13],
				strides=[1, 1, 1, 1], padding='SAME', name='conv_2')
			result_relu_2 = tf.nn.relu(tf.nn.bias_add(result_conv_2, self.b[13], name='add_bias'), name='relu_2')
			# print(result_relu_2.shape[1])

			# up sample
			self.w[14] = self.init_w(shape=[2, 2, 256, 512], name='w_14')
			self.b[14] = self.init_b(shape=[256], name='b_14')
			result_up = tf.nn.conv2d_transpose(
				value=result_relu_2, filter=self.w[14],
				output_shape=[batch_size, 128, 128, 256],
				strides=[1, 2, 2, 1], padding='VALID', name='Up_Sample')
			result_relu_3 = tf.nn.relu(tf.nn.bias_add(result_up, self.b[14], name='add_bias'), name='relu_3')

			# dropout
			result_dropout = tf.nn.dropout(x=result_relu_3, keep_prob=self.keep_prob)

		# layer 7
		with tf.name_scope('layer_7'):
			# copy, crop and merge
			result_merge = self.copy_and_crop_and_merge(
				result_from_contract_layer=self.result_from_contract_layer[3], result_from_upsampling=result_dropout)

			# conv_1
			self.w[15] = self.init_w(shape=[3, 3, 512, 256], name='w_15')
			self.b[15] = self.init_b(shape=[256], name='b_15')
			result_conv_1 = tf.nn.conv2d(
				input=result_merge, filter=self.w[15],
				strides=[1, 1, 1, 1], padding='SAME', name='conv_1')
			result_relu_1 = tf.nn.relu(tf.nn.bias_add(result_conv_1, self.b[15], name='add_bias'), name='relu_1')

			# conv_2
			self.w[16] = self.init_w(shape=[3, 3, 256, 256], name='w_16')
			self.b[16] = self.init_b(shape=[256], name='b_16')
			result_conv_2 = tf.nn.conv2d(
				input=result_relu_1, filter=self.w[16],
				strides=[1, 1, 1, 1], padding='SAME', name='conv_2')
			result_relu_2 = tf.nn.relu(tf.nn.bias_add(result_conv_2, self.b[16], name='add_bias'), name='relu_2')

			# up sample
			self.w[17] = self.init_w(shape=[2, 2, 128, 256], name='w_17')
			self.b[17] = self.init_b(shape=[128], name='b_17')
			result_up = tf.nn.conv2d_transpose(
				value=result_relu_2, filter=self.w[17],
				output_shape=[batch_size, 256, 256, 128],
				strides=[1, 2, 2, 1], padding='VALID', name='Up_Sample')
			result_relu_3 = tf.nn.relu(tf.nn.bias_add(result_up, self.b[17], name='add_bias'), name='relu_3')

			# dropout
			result_dropout = tf.nn.dropout(x=result_relu_3, keep_prob=self.keep_prob)

		# layer 8
		with tf.name_scope('layer_8'):
			# copy, crop and merge
			result_merge = self.copy_and_crop_and_merge(
				result_from_contract_layer=self.result_from_contract_layer[2], result_from_upsampling=result_dropout)

			# conv_1
			self.w[18] = self.init_w(shape=[3, 3, 256, 128], name='w_18')
			self.b[18] = self.init_b(shape=[128], name='b_18')
			result_conv_1 = tf.nn.conv2d(
				input=result_merge, filter=self.w[18],
				strides=[1, 1, 1, 1], padding='SAME', name='conv_1')
			result_relu_1 = tf.nn.relu(tf.nn.bias_add(result_conv_1, self.b[18], name='add_bias'), name='relu_1')

			# conv_2
			self.w[19] = self.init_w(shape=[3, 3, 128, 128], name='w_19')
			self.b[19] = self.init_b(shape=[128], name='b_19')
			result_conv_2 = tf.nn.conv2d(
				input=result_relu_1, filter=self.w[19],
				strides=[1, 1, 1, 1], padding='SAME', name='conv_2')
			result_relu_2 = tf.nn.relu(tf.nn.bias_add(result_conv_2, self.b[19], name='add_bias'), name='relu_2')

			# up sample
			self.w[20] = self.init_w(shape=[2, 2, 64, 128], name='w_20')
			self.b[20] = self.init_b(shape=[64], name='b_20')
			result_up = tf.nn.conv2d_transpose(
				value=result_relu_2, filter=self.w[20],
				output_shape=[batch_size, 512, 512, 64],
				strides=[1, 2, 2, 1], padding='VALID', name='Up_Sample')
			result_relu_3 = tf.nn.relu(tf.nn.bias_add(result_up, self.b[20], name='add_bias'), name='relu_3')

			# dropout
			result_dropout = tf.nn.dropout(x=result_relu_3, keep_prob=self.keep_prob)

		# layer 9
		with tf.name_scope('layer_9'):
			# copy, crop and merge
			result_merge = self.copy_and_crop_and_merge(
				result_from_contract_layer=self.result_from_contract_layer[1], result_from_upsampling=result_dropout)

			# conv_1
			self.w[21] = self.init_w(shape=[3, 3, 128, 64], name='w_21')
			self.b[21] = self.init_b(shape=[64], name='b_21')
			result_conv_1 = tf.nn.conv2d(
				input=result_merge, filter=self.w[21],
				strides=[1, 1, 1, 1], padding='SAME', name='conv_1')
			result_relu_1 = tf.nn.relu(tf.nn.bias_add(result_conv_1, self.b[21], name='add_bias'), name='relu_1')

			# conv_2
			self.w[22] = self.init_w(shape=[3, 3, 64, 64], name='w_22')
			self.b[22] = self.init_b(shape=[64], name='b_22')
			result_conv_2 = tf.nn.conv2d(
				input=result_relu_1, filter=self.w[22],
				strides=[1, 1, 1, 1], padding='SAME', name='conv_2')
			result_relu_2 = tf.nn.relu(tf.nn.bias_add(result_conv_2, self.b[22], name='add_bias'), name='relu_2')

			# 1x1 convolution to [batch_size, OUTPUT_IMG_WIDE, OUTPUT_IMG_HEIGHT, CLASS_NUM]
			self.w[23] = self.init_w(shape=[1, 1, 64, CLASS_NUM], name='w_23')
			self.b[23] = self.init_b(shape=[CLASS_NUM], name='b_23')
			result_conv_3 = tf.nn.conv2d(
				input=result_relu_2, filter=self.w[23],
				strides=[1, 1, 1, 1], padding='VALID', name='conv_3')
			# self.prediction = tf.nn.relu(tf.nn.bias_add(result_conv_3, self.b[23], name='add_bias'), name='relu_3')
			# self.prediction = tf.nn.sigmoid(x=tf.nn.bias_add(result_conv_3, self.b[23], name='add_bias'), name='sigmoid_1')
			self.prediction = tf.nn.bias_add(result_conv_3, self.b[23], name='add_bias')
		# print(self.prediction)
		# print(self.input_label)

		# softmax loss
		with tf.name_scope('softmax_loss'):
			# using one-hot
			# self.loss = \
			# 	tf.nn.softmax_cross_entropy_with_logits(labels=self.cast_label, logits=self.prediction, name='loss')

			# not using one-hot
			self.loss = \
				tf.nn.sparse_softmax_cross_entropy_with_logits(labels=self.input_label, logits=self.prediction, name='loss')
			self.loss_mean = tf.reduce_mean(self.loss)
			tf.add_to_collection(name='loss', value=self.loss_mean)
			self.loss_all = tf.add_n(inputs=tf.get_collection(key='loss'))
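			# note: unlike the boundary-weighted loss discussed above, this implementation uses a plain
			# sparse softmax cross-entropy plus the L2 terms gathered in the 'loss' collection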

		# accuracy
		with tf.name_scope('accuracy'):
			# using one-hot
			# self.correct_prediction = tf.equal(tf.argmax(self.prediction, axis=3), tf.argmax(self.cast_label, axis=3))

			# not using one-hot
			self.correct_prediction = \
				tf.equal(tf.argmax(input=self.prediction, axis=3, output_type=tf.int32), self.input_label)
			self.correct_prediction = tf.cast(self.correct_prediction, tf.float32)
			self.accuracy = tf.reduce_mean(self.correct_prediction)

		# Gradient Descent
		with tf.name_scope('Gradient_Descent'):
			self.train_step = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(self.loss_all)

	def train(self):
		# import cv2
		# import numpy as np
		# ckpt_path = os.path.join(FLAGS.model_dir, "model.ckpt")
		# all_parameters_saver = tf.train.Saver()
		# # import numpy as np
		# # mydata = DataProcess(INPUT_IMG_HEIGHT, INPUT_IMG_WIDE)
		# # imgs_train, imgs_mask_train = mydata.load_my_train_data()
		# my_set_image = cv2.imread('../data_set/train.tif', flags=0)
		# my_set_label = cv2.imread('../data_set/label.tif', flags=0)
		# my_set_image.astype('float32')
		# my_set_label[my_set_label <= 128] = 0
		# my_set_label[my_set_label > 128] = 1
		# my_set_image = np.reshape(a=my_set_image, newshape=(1, INPUT_IMG_WIDE, INPUT_IMG_HEIGHT, INPUT_IMG_CHANNEL))
		# my_set_label = np.reshape(a=my_set_label, newshape=(1, OUTPUT_IMG_WIDE, OUTPUT_IMG_HEIGHT))
		# # cv2.imshow('image', my_set_image)
		# # cv2.imshow('label', my_set_label * 100)
		# # cv2.waitKey(0)
		# with tf.Session() as sess:  # start a session
		# 	sess.run(tf.global_variables_initializer())
		# 	sess.run(tf.local_variables_initializer())
		# 	for epoch in range(10):
		# 		lo, acc = sess.run(
		# 			[self.loss_mean, self.accuracy],
		# 			feed_dict={
		# 				self.input_image: my_set_image, self.input_label: my_set_label,
		# 				self.keep_prob: 1.0, self.lamb: 0.004}
		# 		)
		# 		# print('num %d, loss: %.6f and accuracy: %.6f' % (epoch, lo, acc))
		# 		print('num %d, loss: %.6f and accuracy: %.6f' % (epoch, lo, acc))
		# 		sess.run(
		# 			[self.train_step],
		# 			feed_dict={
		# 				self.input_image: my_set_image, self.input_label: my_set_label,
		# 				self.keep_prob: 0.6, self.lamb: 0.004}
		# 		)
		# 	all_parameters_saver.save(sess=sess, save_path=ckpt_path)
		# print("Done training")
		train_file_path = os.path.join(FLAGS.data_dir, TRAIN_SET_NAME)
		train_image_filename_queue = tf.train.string_input_producer(
			string_tensor=tf.train.match_filenames_once(train_file_path), num_epochs=EPOCH_NUM, shuffle=True)
		ckpt_path = CHECK_POINT_PATH
		train_images, train_labels = read_image_batch(train_image_filename_queue, TRAIN_BATCH_SIZE)
		tf.summary.scalar("loss", self.loss_mean)
		tf.summary.scalar('accuracy', self.accuracy)
		merged_summary = tf.summary.merge_all()
		all_parameters_saver = tf.train.Saver()
		with tf.Session() as sess:  # start a session
			sess.run(tf.global_variables_initializer())
			sess.run(tf.local_variables_initializer())
			summary_writer = tf.summary.FileWriter(FLAGS.tb_dir, sess.graph)
			tf.summary.FileWriter(FLAGS.model_dir, sess.graph)
			coord = tf.train.Coordinator()
			threads = tf.train.start_queue_runners(coord=coord)
			try:
				epoch = 1
				while not coord.should_stop():
					# Run training steps or whatever
					# print('epoch ' + str(epoch))
					example, label = sess.run([train_images, train_labels])  # fetch an image batch and its labels from the queue
					# print(label)
					lo, acc, summary_str = sess.run(
						[self.loss_mean, self.accuracy, merged_summary],
						feed_dict={
							self.input_image: example, self.input_label: label, self.keep_prob: 1.0,
							self.lamb: 0.004}
					)
					summary_writer.add_summary(summary_str, epoch)
					# print('num %d, loss: %.6f and accuracy: %.6f' % (epoch, lo, acc))
					if epoch % 10 == 0:
						print('num %d, loss: %.6f and accuracy: %.6f' % (epoch, lo, acc))
					sess.run(
						[self.train_step],
						feed_dict={
							self.input_image: example, self.input_label: label, self.keep_prob: 0.6,
							self.lamb: 0.004}
					)
					epoch += 1
			except tf.errors.OutOfRangeError:
				print('Done training -- epoch limit reached')
			finally:
				# When done, ask the threads to stop.
				all_parameters_saver.save(sess=sess, save_path=ckpt_path)
				coord.request_stop()
			# coord.request_stop()
			coord.join(threads)
		print("Done training")

	def validate(self):
		# import cv2
		# import numpy as np
		# ckpt_path = os.path.join(FLAGS.model_dir, "model.ckpt")
		# # mydata = DataProcess(INPUT_IMG_HEIGHT, INPUT_IMG_WIDE)
		# # imgs_train, imgs_mask_train = mydata.load_my_train_data()
		# all_parameters_saver = tf.train.Saver()
		# my_set_image = cv2.imread('../data_set/train.tif', flags=0)
		# my_set_label = cv2.imread('../data_set/label.tif', flags=0)
		# my_set_image.astype('float32')
		# my_set_label[my_set_label <= 128] = 0
		# my_set_label[my_set_label > 128] = 1
		# with tf.Session() as sess:
		# 	sess.run(tf.global_variables_initializer())
		# 	sess.run(tf.local_variables_initializer())
		# 	all_parameters_saver.restore(sess=sess, save_path=ckpt_path)
		# 	image, acc = sess.run(
		# 		fetches=[self.prediction, self.accuracy],
		# 		feed_dict={
		# 				self.input_image: my_set_image, self.input_label: my_set_label,
		# 				self.keep_prob: 1.0, self.lamb: 0.004}
		# 	)
		# image = np.argmax(a=image[0], axis=2).astype('uint8') * 255
		# # cv2.imshow('predict', image)
		# # cv2.imshow('o', np.asarray(a=image[0], dtype=np.uint8) * 100)
		# # cv2.waitKey(0)
		# cv2.imwrite(filename=os.path.join(FLAGS.model_dir, 'predict.jpg'), img=image)
		# print(acc)
		# print("Done test, predict image has been saved to %s" % (os.path.join(FLAGS.model_dir, 'predict.jpg')))
		validation_file_path = os.path.join(FLAGS.data_dir, VALIDATION_SET_NAME)
		validation_image_filename_queue = tf.train.string_input_producer(
			string_tensor=tf.train.match_filenames_once(validation_file_path), num_epochs=1, shuffle=True)
		ckpt_path = CHECK_POINT_PATH
		validation_images, validation_labels = read_image_batch(validation_image_filename_queue, VALIDATION_BATCH_SIZE)
		# tf.summary.scalar("loss", self.loss_mean)
		# tf.summary.scalar('accuracy', self.accuracy)
		# merged_summary = tf.summary.merge_all()
		all_parameters_saver = tf.train.Saver()
		with tf.Session() as sess:  # start a session
			sess.run(tf.global_variables_initializer())
			sess.run(tf.local_variables_initializer())
			# summary_writer = tf.summary.FileWriter(FLAGS.tb_dir, sess.graph)
			# tf.summary.FileWriter(FLAGS.model_dir, sess.graph)
			all_parameters_saver.restore(sess=sess, save_path=ckpt_path)
			coord = tf.train.Coordinator()
			threads = tf.train.start_queue_runners(coord=coord)
			try:
				epoch = 1
				while not coord.should_stop():
					# Run training steps or whatever
					# print('epoch ' + str(epoch))
					example, label = sess.run([validation_images, validation_labels])  # fetch an image batch and its labels from the queue
					# print(label)
					lo, acc = sess.run(
						[self.loss_mean, self.accuracy],
						feed_dict={
							self.input_image: example, self.input_label: label, self.keep_prob: 1.0,
							self.lamb: 0.004}
					)
					# summary_writer.add_summary(summary_str, epoch)
					# print('num %d, loss: %.6f and accuracy: %.6f' % (epoch, lo, acc))
					if epoch % 1 == 0:
						print('num %d, loss: %.6f and accuracy: %.6f' % (epoch, lo, acc))
					epoch += 1
			except tf.errors.OutOfRangeError:
				print('Done validating -- epoch limit reached')
			finally:
				# When done, ask the threads to stop.
				coord.request_stop()
			# coord.request_stop()
			coord.join(threads)
		print('Done validating')

	def test(self):
		import cv2
		test_file_path = os.path.join(FLAGS.data_dir, TEST_SET_NAME)
		test_image_filename_queue = tf.train.string_input_producer(
			string_tensor=tf.train.match_filenames_once(test_file_path), num_epochs=1, shuffle=True)
		ckpt_path = CHECK_POINT_PATH
		test_images, test_labels = read_image_batch(test_image_filename_queue, TEST_BATCH_SIZE)
		# tf.summary.scalar("loss", self.loss_mean)
		# tf.summary.scalar('accuracy', self.accuracy)
		# merged_summary = tf.summary.merge_all()
		all_parameters_saver = tf.train.Saver()
		with tf.Session() as sess:  # start a session
			sess.run(tf.global_variables_initializer())
			sess.run(tf.local_variables_initializer())
			# summary_writer = tf.summary.FileWriter(FLAGS.tb_dir, sess.graph)
			# tf.summary.FileWriter(FLAGS.model_dir, sess.graph)
			all_parameters_saver.restore(sess=sess, save_path=ckpt_path)
			coord = tf.train.Coordinator()
			threads = tf.train.start_queue_runners(coord=coord)
			sum_acc = 0.0
			try:
				epoch = 0
				while not coord.should_stop():
					# Run training steps or whatever
					# print('epoch ' + str(epoch))
					example, label = sess.run([test_images, test_labels])  # fetch an image batch and its labels from the queue
					# print(label)
					image, acc = sess.run(
						[tf.argmax(input=self.prediction, axis=3), self.accuracy],
						feed_dict={
							self.input_image: example, self.input_label: label,
							self.keep_prob: 1.0, self.lamb: 0.004
						}
					)
					sum_acc += acc
					epoch += 1
					cv2.imwrite(os.path.join(PREDICT_SAVED_DIRECTORY, '%d.jpg' % epoch), image[0] * 255)
					if epoch % 1 == 0:
						print('num %d accuracy: %.6f' % (epoch, acc))
			except tf.errors.OutOfRangeError:
				print('Done testing -- epoch limit reached \n Average accuracy: %.2f%%' % (sum_acc / TEST_SET_SIZE * 100))
			finally:
				# When done, ask the threads to stop.
				coord.request_stop()
			# coord.request_stop()
			coord.join(threads)
		print('Done testing')

	def predict(self):
		import cv2
		import glob
		import numpy as np
		# TODO this should not read from the tfrecord; the images should be read and predicted directly, because the queue order changes and predictions cannot be matched back to their files
		predict_file_path = glob.glob(os.path.join(ORIGIN_PREDICT_DIRECTORY, '*.tif'))
		print(len(predict_file_path))
		if not os.path.lexists(PREDICT_SAVED_DIRECTORY):
			os.mkdir(PREDICT_SAVED_DIRECTORY)
		ckpt_path = CHECK_POINT_PATH
		all_parameters_saver = tf.train.Saver()
		with tf.Session() as sess:  # start a session
			sess.run(tf.global_variables_initializer())
			sess.run(tf.local_variables_initializer())
			# summary_writer = tf.summary.FileWriter(FLAGS.tb_dir, sess.graph)
			# tf.summary.FileWriter(FLAGS.model_dir, sess.graph)
			all_parameters_saver.restore(sess=sess, save_path=ckpt_path)
			for index, image_path in enumerate(predict_file_path):
				# image = cv2.imread(image_path, flags=0)
				image = np.reshape(a=cv2.imread(image_path, flags=0), newshape=(1, INPUT_IMG_WIDE, INPUT_IMG_HEIGHT, INPUT_IMG_CHANNEL))
				predict_image = sess.run(
					tf.argmax(input=self.prediction, axis=3),
					feed_dict={
						self.input_image: image,
						self.keep_prob: 1.0, self.lamb: 0.004
					}
				)
				cv2.imwrite(os.path.join(PREDICT_SAVED_DIRECTORY, '%d.jpg' % index), predict_image[0] * 255)
		print('Done prediction')
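For completeness, here is a hedged sketch of how the class might be driven (the repository's actual main/FLAGS handling is not shown in this excerpt):

def main(_):
	net = Unet()
	net.set_up_unet(TRAIN_BATCH_SIZE)
	net.train()
	# to validate or test, rebuild the graph with the matching batch size first, e.g.:
	# tf.reset_default_graph(); net = Unet(); net.set_up_unet(VALIDATION_BATCH_SIZE); net.validate()


if __name__ == '__main__':
	tf.app.run(main=main)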

 

