最近,帮老婆debug过程中,也学习了一些tensorflow的相关知识。其中,花了比较多的实验研究如何使用tf.dataset来导入数据。以下是对一个demo的小记录,完成的工作是在tensorflow上实现AlexNet,使用的数据为MNIST。
由于国内网络原因,直接用tf中的函数读取mnist数据集较慢,我选择使用keras.datasets中import mnist数据集。这就导致后续的训练中,无法直接使用mnist.train.next_batch(batch_size)来生成mnist的训练数据。因此,选择使用tf.dataset将读取的numpy数据转化为tf所需格式。
读取数据代码如下:
from keras.datssets import mnist
from keras.utils import to_categorical
def load_mnist(image_size):
(x_train,y_train),(x_test,y_test) = mnist.load_data()
train_image = [cv2.cvtColor(cv2.resize(img,(image_size,image_size)),cv2.COLOR_GRAY2BGR) for img in x_train]
test_image = [cv2.cvtColor(cv2.resize(img,(image_size,image_size)),cv2.COLOR_GRAY2BGR) for img in x_test]
train_image = np.asarray(train_image, 'f')
test_image = np.asarray(test_image, 'f')
train_label = to_categorical(y_train)
test_label = to_categorical(y_test)
print('finish loading data!')
return train_image, train_label, test_image, test_label
load_mnist函数完成了mnist数据的读取,并将单通道转化为3通道(使用opencv函数),将label转化为onehot编码。
接下来,进行网络的初始化。网络结构转载自老婆的博客。
首先定义几个层结构:
# define layers
def conv(x, kernel, strides, b):
return tf.nn.relu(tf.nn.bias_add(tf.nn.conv2d(x, kernel, strides, padding = 'SAME'), b))
def max_pooling(x, kernel, strides):
return tf.nn.max_pool(x, kernel, strides, padding = 'VALID')
def fc(x, w, b):
return tf.nn.relu(tf.add(tf.matmul(x,w),b))
其次,定义网络初始权值:
# define variables
weights = {
'wc1':tf.Variable(tf.random_normal([11,