# 利用ResNet的思想构建神经网络训练CIFAR-10

#coding=utf-8
import tensorflow as tf
import os
import pickle
import numpy as np


# Directory holding the pickled CIFAR-10 batch files
# (data_batch_1 ... data_batch_5 and test_batch are read from it below).
CIFAR_DIR = "./cifar-10-batches-py"
# List the directory up front: a wrong path fails right here with
# FileNotFoundError instead of later, inside data loading.
print(os.listdir(CIFAR_DIR))


def residual_block(x,output_channel):
	"""Build one two-convolution residual block.

	Args:
		x: 4-D input tensor with channels on the last axis
			([batch, height, width, channel]).
		output_channel: number of output channels. It must equal the input
			channel count (identity shortcut, stride 1) or be exactly twice
			it (downsampling block: stride-2 conv plus a pooled,
			zero-padded shortcut).

	Returns:
		Tensor with `output_channel` channels: conv2(conv1(x)) + shortcut(x).

	Raises:
		ValueError: if output_channel is neither input_channel nor
			2 * input_channel. ValueError subclasses Exception, so existing
			`except Exception` handlers keep working.
	"""
	input_channel=x.get_shape().as_list()[-1]
	if output_channel==input_channel*2:
		increase_dim=True
		strides=(2,2)
	elif input_channel==output_channel:
		increase_dim=False
		strides=(1,1)
	else:
		raise ValueError("input_channel don't match output_channel")
	conv1=tf.layers.conv2d(x,
							output_channel,
							(3,3),
							strides=strides,
							padding='SAME',
							activation=tf.nn.relu,
							name='conv1')
	# NOTE(review): ReLU is applied inside conv2, i.e. before the residual
	# addition; the original ResNet applies it after the addition. Kept
	# as-is so the network computes the same function as before.
	conv2=tf.layers.conv2d(conv1,
							output_channel,
							(3,3),
							strides=(1,1),
							padding="SAME",
							activation=tf.nn.relu,
							name='conv2')
	if increase_dim:
		# Halve the spatial size so the shortcut matches conv1's stride-2
		# output. 'SAME' yields ceil(H/2) x ceil(W/2), the same size as a
		# stride-2 'SAME' convolution; 'VALID' yields floor(H/2) and breaks
		# the addition for odd-sized maps. For even sizes both agree.
		pooled_x=tf.layers.average_pooling2d(x,
											(2,2),
											(2,2),
											padding="SAME")
		# Zero-pad the channel axis from input_channel to 2*input_channel.
		# Splitting the padding as (c//2, c - c//2) keeps the total at
		# exactly input_channel even when input_channel is odd.
		pad_before=input_channel//2
		pad_after=input_channel-pad_before
		shortcut=tf.pad(pooled_x,
						[[0,0],
						[0,0],
						[0,0],
						[pad_before,pad_after]])
	else:
		shortcut=x

	output_x=conv2+shortcut
	return output_x
def resNet(x,num_blocks,num_filter_base,class_num):
	"""Build a ResNet-style classifier and return its unnormalised logits.

	Args:
		x: 4-D input tensor [batch, height, width, channel]; height and
			width must be static because they are read via get_shape().
		num_blocks: residual-block count per stage. Stage k uses
			num_filter_base * 2**k filters, so the first block of every
			stage after the first halves the spatial size.
		num_filter_base: filter count of the stem conv and the first stage.
		class_num: number of output classes.

	Returns:
		Tensor [batch, class_num] of raw logits, with no softmax applied.
		Softmax-based losses such as softmax_cross_entropy_with_logits
		apply the softmax themselves. argmax over logits equals argmax over
		the probabilities.
	"""
	num_sampling=len(num_blocks)
	layers=[]
	#[None,height,width,channel]
	input_size=x.get_shape().as_list()[1:]
	with tf.variable_scope('conv0'):
		conv0=tf.layers.conv2d(x,
			num_filter_base,
			(3,3),
			strides=(1,1),
			padding="SAME",activation=tf.nn.relu,
			name="conv0")
		layers.append(conv0)
	for sample_id in range(num_sampling):
		for i in range(num_blocks[sample_id]):
			with tf.variable_scope("conv%d_%d"%(sample_id,i)):
				conv=residual_block(
					layers[-1],
					num_filter_base*(2**sample_id))
				layers.append(conv)
	multiplier_output=2**(num_sampling-1)
	# Sanity-check the final feature-map shape. Each downsampling uses
	# 'SAME' padding (ceil), and repeated ceil-halving equals one ceil
	# division, so the expected size is ceil(size / multiplier) as an int.
	expected_h=-(-input_size[0]//multiplier_output)
	expected_w=-(-input_size[1]//multiplier_output)
	assert layers[-1].get_shape().as_list()[1:]==\
		[expected_h,
		expected_w,
		num_filter_base*multiplier_output]

	with tf.variable_scope('fc'):
		# Global average pooling over the spatial axes -> [batch, channel].
		global_pool=tf.reduce_mean(layers[-1],[1,2])
		# No activation: these are logits. Adding a softmax here would
		# apply softmax twice once the loss applies its own, flattening the
		# gradients.
		logits=tf.layers.dense(global_pool,class_num)
		layers.append(logits)
	return layers[-1]

def load_data(filename):
	"""Unpickle one CIFAR-10 batch file and return (data, labels).

	The pickle is decoded with encoding='bytes', so the dictionary keys are
	bytes objects and the values are looked up as b'data' and b'labels'.

	SECURITY: unpickling can execute arbitrary code; only point this at
	trusted dataset files.
	"""
	with open(filename,'rb') as handle:
		batch=pickle.load(handle,encoding='bytes')
	images=batch[b'data']
	labels=batch[b'labels']
	return images,labels



class CifarData(object):
	"""In-memory CIFAR-10 dataset served as sequential mini-batches.

	Each pixel value v is rescaled to v / 127.5 - 1, which maps [0, 255]
	to [-1, 1] for raw 8-bit pixels.

	Args:
		filename: path of one pickled batch file, or an iterable of such
			paths whose rows are concatenated in order (e.g. all five
			data_batch_* files). A single path behaves exactly as before.
		shuffle: if True, rows are shuffled on construction and reshuffled
			each time next_batch runs past the end. If False, running past
			the end raises.
	"""
	def __init__(self, filename,shuffle):
		if isinstance(filename,(str,bytes,os.PathLike)):
			filenames=[filename]
		else:
			filenames=list(filename)
		if not filenames:
			raise ValueError("CifarData needs at least one batch file")
		data_parts=[]
		label_parts=[]
		for name in filenames:
			data,labels=load_data(name)
			data_parts.append(data)
			label_parts.append(labels)
		self._data=np.vstack(data_parts)/127.5-1
		self._label=np.hstack(label_parts)
		# Number of examples; read directly by callers to size batches.
		self._size=len(self._label)
		self._num_examples=self._data.shape[0]
		self._need_shuffle=shuffle
		# Index of the first row of the next batch.
		self._indicator=0
		if self._need_shuffle:
			self._shuffle_data()

	def _shuffle_data(self):
		"""Apply one random permutation to data and labels together."""
		p=np.random.permutation(self._num_examples)
		self._data=self._data[p]
		self._label=self._label[p]

	def next_batch(self,batch_size):
		"""Return the next (data, labels) mini-batch of batch_size rows.

		Raises:
			Exception: when the data is exhausted and shuffling is disabled,
				or when batch_size exceeds the number of examples.
		"""
		end_indicator=self._indicator+batch_size
		if end_indicator>self._num_examples:
			if not self._need_shuffle:
				raise Exception("no more examples")
			# Start a new pass over freshly shuffled data. The end index must
			# be reset as well. Otherwise _indicator is left past the end,
			# and every later call reshuffles and returns only the first
			# batch_size rows.
			self._shuffle_data()
			self._indicator=0
			end_indicator=batch_size
			if end_indicator>self._num_examples:
				raise Exception("batch size is larger than all examples")
		batch_data=self._data[self._indicator:end_indicator]
		batch_label=self._label[self._indicator:end_indicator]
		self._indicator=end_indicator
		return batch_data,batch_label

# Five training batch files (data_batch_1 .. data_batch_5) plus one test file.
train_filenames=[os.path.join(CIFAR_DIR,'data_batch_{}'.format(batch_no)) for batch_no in range(1,6)]
test_filenames=os.path.join(CIFAR_DIR,'test_batch')

# The test set is shuffled so that repeated next_batch calls can wrap
# around; with shuffle=False, CifarData.next_batch raises once the data
# runs out.
test_data=CifarData(test_filenames,True)

# Each input row is a flat 3072-vector holding 3 planes of 32x32 values,
# channel-first. Reshape to [N, 3, 32, 32], then transpose to
# [N, 32, 32, 3] (channels last, the tf.layers default).
x=tf.placeholder(tf.float32,[None,3072])
x_image=tf.transpose(tf.reshape(x,[-1,3,32,32]),perm=[0,2,3,1])
# Integer class labels, one per example.
y=tf.placeholder(tf.int64,[None])

# Call the ResNet builder: three stages of 2, 3 and 2 residual blocks,
# 32 base filters, 10 classes.
p_y=resNet(x_image,[2,3,2],32,10)

# Mean cross-entropy against one-hot targets.
y_one_hot=tf.one_hot(y,10,dtype=tf.float32)
per_example_loss=tf.nn.softmax_cross_entropy_with_logits(labels=y_one_hot,logits=p_y)
loss=tf.reduce_mean(per_example_loss)
loss_summary=tf.summary.scalar("loss",loss)

# Accuracy: fraction of examples whose arg-max class equals the label.
predict=tf.argmax(p_y,1)
correct_prediction=tf.equal(tf.cast(predict,tf.int32),tf.cast(y,tf.int32))
correct_as_float=tf.cast(correct_prediction,tf.float32)
accuracy=tf.reduce_mean(correct_as_float)
# Per-example correctness as 0.0 / 1.0.
accuracy_t=tf.cast(correct_prediction,tf.float32)
accuracy_summary=tf.summary.scalar("accuracy",accuracy)
with tf.name_scope('train_op'):
	optimizer=tf.train.AdamOptimizer(0.001)
	train_op=optimizer.minimize(loss)


init=tf.global_variables_initializer()
# Training settings. train_epochs is the number of outer training-loop
# iterations; each iteration trains on mini-batches, not a full data pass.
batch_size=128
train_epochs=3000
accuracy_test=0.0
output_summary_everystep=5
merged_summary=tf.summary.merge_all()

# TensorBoard layout: <LOG_DIR>/ResNet_visibal/{train,test}
LOG_DIR='.'
run_file="ResNet_visibal"
train_log="train"
test_log="test"
LOG_DIR_file=os.path.join(LOG_DIR,run_file)
train_log_run=os.path.join(LOG_DIR_file,train_log)
test_log_run=os.path.join(LOG_DIR_file,test_log)
# Create the parent first, then the two run directories, skipping any that
# already exist.
for log_dir in (LOG_DIR_file,train_log_run,test_log_run):
	if not os.path.exists(log_dir):
		os.mkdir(log_dir)
with tf.Session() as sess:
	sess.run(init)
	train_log_writer=tf.summary.FileWriter(train_log_run,sess.graph)
	test_log_writer=tf.summary.FileWriter(test_log_run)
	# Load every training file once, up front. Building CifarData inside
	# the loop re-reads and reshuffles a pickle from disk on every step and
	# only ever yields that file's first batch. This keeps all training
	# files in memory.
	train_datasets=[CifarData(train_filename,True) for train_filename in train_filenames]
	# The test set is evaluated in this many slices per summary step.
	num_test_batches=7
	test_batch_size=test_data._size//num_test_batches
	try:
		for i in range(train_epochs):
			write_summary=(i%output_summary_everystep==0)
			# Build the fetch list once per step. Appending inside the
			# per-file loop duplicated merged_summary for every file.
			run_list=[loss,accuracy,train_op]
			if write_summary:
				run_list.append(merged_summary)
			# One optimisation step per training file.
			for train_data in train_datasets:
				batch_data,batch_label=train_data.next_batch(batch_size)
				run_result=sess.run(run_list,feed_dict={x:batch_data,y:batch_label})
				loss_val,acc_val=run_result[0:2]
			if not write_summary:
				continue
			# Summary from the last training batch of this step.
			train_log_writer.add_summary(run_result[-1],i+1)

			# Evaluate on all test slices and average, rather than
			# reporting only the last slice.
			test_accs=[]
			test_losses=[]
			for _ in range(num_test_batches):
				batch_test_data,batch_test_label=test_data.next_batch(test_batch_size)
				acc_batch,loss_batch=sess.run([accuracy,loss],feed_dict={x:batch_test_data,y:batch_test_label})
				test_accs.append(acc_batch)
				test_losses.append(loss_batch)
			acc_test=float(np.mean(test_accs))
			loss_test=float(np.mean(test_losses))
			# Log the averaged values under the same tags as the graph's
			# scalar summaries ("loss", "accuracy").
			test_summary=tf.Summary(value=[
				tf.Summary.Value(tag="loss",simple_value=loss_test),
				tf.Summary.Value(tag="accuracy",simple_value=acc_test)])
			test_log_writer.add_summary(test_summary,i+1)
			print('[Train] step:%d, loss: %4.5f, acc:%4.5f,acc_test:%f'%(i,loss_val,acc_val,acc_test))
	finally:
		# Flush and close the event files even if training is interrupted.
		train_log_writer.close()
		test_log_writer.close()

 

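# (Web-page residue captured with the source; not part of the program.)
# 评论
# 添加红包
#
# 请填写红包祝福语或标题
#
# 红包个数最小为10个
#
# 红包金额最低5元
#
# 当前余额3.43前往充值 >
# 需支付:10.00
# 成就一亿技术人!
# 领取后你会自动成为博主和红包主的粉丝 规则
# hope_wisdom
# 发出的红包
# 实付
# 使用余额支付
# 点击重新获取
# 扫码支付
# 钱包余额 0
#
# 抵扣说明:
#
# 1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
# 2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。
#
# 余额充值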