TensorFlow上分之路
pycharm代码结构
今天开始了我的tensorflow上分之路,之前一直都是一个giter,老是去剽窃别人在github上的代码,有一天觉得这不是长久之计,所以准备开始试着自己去学习写代码,去研究别人的风格。于是乎找了一个个人认为还算绝美的风格格式,以这个格式为模板开始写自己的代码,写在这上面作为自己的备忘录,并且之后所有相关的算法代码都会保持这个风格。
代码整体结构如下
checkpoint
这里是tensorflow每个epoch训练好的模型权重
data
用于存储数据,这里我们简单点,就先以mnist数据集为开始
common
这里将是我编写算法中常用的基本函数的地方,主要包括卷积,全连接这样的函数,今天第一天先简单一点
import tensorflow as tf
def weight_and_bias(input_data,name,trainable=True,bn=True,activate=True):
with tf.variable_scope(name):
weights=tf.Variable(tf.zeros([784,10]))
biases=tf.Variable(tf.zeros([1,10]))
predict = tf.nn.softmax(tf.matmul(input_data, weights) + biases)
return predict
config
这里是做文件基本配置的地方,关系到整体算法中的所有参数配置
from easydict import EasyDict as edict
__C=edict()
cfg=__C
__C.model=edict()
__C.model.the_train_data='./data/train-images-idx3-ubyte.gz'
__C.model.the_train_label='./data/train-labels-idx1-ubyte.gz'
__C.model.the_train_batch_size=100
__C.model.the_test_data='./data/t10k-images-idx3-ubyte.gz'
__C.model.the_test_label='./data/t10k-labels-idx1-ubyte.gz'
__C.model.the_test_batch_size=20
datalist
这里是我准备数据集的地方,一开始学的时候没看懂,看了好久突然明白了 iter 和 next的关系。
这里主要负责在train的时候输出我规定好的每个batch数据集以及他们的标签
import warnings
warnings.filterwarnings('ignore')
from yolov3_tf.config import cfg
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
import tensorflow as tf
class Dataset(object):
def __init__(self,data_type):
# self.data_path=cfg.model.the_train_data if data_type=="train" else cfg.model.the_test_data
self.batch_size=cfg.model.the_train_batch_size if data_type=="train" else cfg.model.the_test_batch_size
self.batch_count=0
self.load_data=input_data.read_data_sets('data',one_hot=True)
self.num_samples=self.load_data.train.num_examples if data_type=="train" else self.load_data.test.num_examples
self.batch_num=int(np.ceil(self.num_samples//self.batch_size))
def __iter__(self):
return self
def __next__(self):
with tf.device('/cpu:0'):
batch_images=np.zeros((self.batch_size,784))
batch_labels=np.zeros((self.batch_size,10))
num=0
if self.batch_count<self.batch_num:
while num<self.batch_size:
index=self.batch_count*self.batch_size+num
image=self.load_data.train.images[index]
label=self.load_data.train.labels[index]
batch_images[num,:]=np.array(image)
batch_labels[num,:]=label
num+=1
self.batch_count+=1
#按批次传输 images 和 labels
return batch_images,batch_labels
else:
self.batch_count=0
raise StopIteration
def __len__(self):
return self.batch_num
model
这里是算法的重头戏之一,建立模型。作为这个系列的第一集,算法也不去像github上的大佬们那样写的很复杂的模型,因为我们跑的是mnist数据集,最基础的一个数据集,而且我在common中也就写了一个weight_and_bias的函数,所以这次的模型就弄得很简单
import tensorflow as tf
from yolov3_tf.common import weight_and_bias
class model(object):
def __init__(self,the_input_data,trainable):
self.trainable=trainable
self.model_input_data=the_input_data
try:
self.predict=self.__build_network()
except:
raise NotImplementedError("Can not build up the network")
def __build_network(self):
predict=weight_and_bias(self.model_input_data,trainable=self.trainable,name="weight_and_bias")
return predict
def compute_loss(self,input_label):
loss=tf.reduce_mean(-tf.reduce_sum(input_label * tf.log(self.predict), axis=1), axis=0)
return loss
train
这使我们训练的重头戏,为训练模型,以及保存模型,这里我就直接上代码,主要还是根据test的performance来保存模型
import warnings
warnings.filterwarnings('ignore')
import tensorflow as tf
from yolov3_tf.model import model
import os
import shutil
from tqdm import tqdm
from yolov3_tf.datalist import Dataset
import numpy as np
class Train(object):
def __init__(self):
self.total_epoches=200
self.trainset=Dataset("train")
self.testset=Dataset("test")
# 如果你指定的设备不存在,允许TF自动分配设备
self.sess=tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
with tf.name_scope("define_input"):
self.input_data = tf.placeholder(dtype=tf.float32,name="input_data")
self.input_label = tf.placeholder(dtype=tf.float32,name="input_label")
self.trainable = tf.placeholder(dtype=tf.bool,name="training")
with tf.name_scope("define_loss"):
self.model=model(self.input_data,self.trainable)
self.net_var=tf.global_variables()
self.loss=self.model.compute_loss(self.input_label)
with tf.name_scope("learn_rate"):
self.global_step=tf.Variable(1.0,dtype=tf.float64,trainable=False,name="global_step")
self.learn_rate=0.00001
gloabel_step_update=tf.assign_add(self.global_step,1.0)
with tf.name_scope("define_train"):
train_stage_trainable_var_list=tf.trainable_variables()
train_stage_optimizer=tf.train.AdamOptimizer(self.learn_rate).minimize(self.loss,var_list=train_stage_trainable_var_list)
with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
with tf.control_dependencies([train_stage_optimizer,gloabel_step_update]):
self.train_stage_op_with_all_variables=tf.no_op()
with tf.name_scope("loader_and_saver"):
self.loader=tf.train.Saver(self.net_var)
self.saver=tf.train.Saver(tf.global_variables(),max_to_keep=10)
with tf.name_scope("summary"):
tf.summary.scalar("learn_rate",self.learn_rate)
tf.summary.scalar("loss",self.loss)
logdir="./data/log/"
if os.path.exists(logdir):
shutil.rmtree(logdir)
os.mkdir(logdir)
self.write_op=tf.summary.merge_all()
self.summary_writer=tf.summary.FileWriter(logdir,graph=self.sess.graph)
def train(self):
print("=> Now it starts to train the model from stratch ...")
self.sess.run(tf.global_variables_initializer())
for epoch in range(1,1+self.total_epoches):
train_op=self.train_stage_op_with_all_variables
pbar=tqdm(self.trainset)
train_epoch_loss,test_epoch_loss=[],[]
for train_data in pbar:
_,summary,train_step_loss,global_step_val=self.sess.run(
[train_op,self.write_op,self.loss,self.global_step],
feed_dict={
self.trainable: True,
self.input_data: train_data[0],
self.input_label: train_data[1],
}
)
train_epoch_loss.append(train_step_loss)
self.summary_writer.add_summary(summary,global_step_val)
pbar.set_description("train loss:%.2f" %train_step_loss)
for test_data in self.testset:
test_step_loss=self.sess.run(
self.loss,
feed_dict={
self.input_data:test_data[0],
self.input_label:test_data[1],
self.trainable:False,
}
)
test_epoch_loss.append(test_step_loss)
train_epoch_loss,test_epoch_loss=np.mean(train_epoch_loss),np.mean(test_epoch_loss)
ckpt_file="./checkpoint/model_test_loss=%.4f.ckpt" %test_epoch_loss
print("=> Epoch: %2d Train loss: %.2f Test loss: %.2f Saving %s "
%(epoch,train_epoch_loss,test_epoch_loss,ckpt_file))
self.saver.save(self.sess,ckpt_file,global_step=epoch)
if __name__=='__main__':
print(Train().train())
注
以上的代码风格,我是在看一篇yolov3-tf版本中看到的,感觉他的代码写的极度优美就忍不住抄下来反复的看。
学习算法本身就是一件很难的事,而github上风格千千万万,如果能有一个好的风格模板并持之以恒相信会对学习有很大的帮助
[1]: https://github.com/YunYang1994/tensorflow-yolov3