TensorFlow上分之路

pycharm代码结构

今天开始了我的tensorflow上分之路,之前一直都是一个giter,老是去剽窃别人在github上的代码,有一天觉得这不是长久之计,所以准备开始试着自己去学习写代码,去研究别人的风格。于是乎找了一个个人认为还算绝美的风格格式,以这个格式为模板开始写自己的代码,写在这上面作为自己的备忘录,并且之后所有相关的算法代码都会保持这个风格。
代码整体结构如下
在这里插入图片描述

checkpoint

这里是tensorflow每个epoch训练好的模型权重

在这里插入图片描述

data

用于存储数据,这里我们简单点,就先以mnist数据集为开始
在这里插入图片描述

common

这里将是我编写算法中常用的基本函数的地方,主要包括卷积,全连接这样的函数,今天第一天先简单一点

import tensorflow as tf

def weight_and_bias(input_data,name,trainable=True,bn=True,activate=True):
    with tf.variable_scope(name):
        weights=tf.Variable(tf.zeros([784,10]))
        biases=tf.Variable(tf.zeros([1,10]))
        predict = tf.nn.softmax(tf.matmul(input_data, weights) + biases)
    return predict

config

这里是做文件基本配置的地方,关系到整体算法中的所有参数配置

from easydict import EasyDict as edict

__C=edict()


cfg=__C

__C.model=edict()

__C.model.the_train_data='./data/train-images-idx3-ubyte.gz'
__C.model.the_train_label='./data/train-labels-idx1-ubyte.gz'
__C.model.the_train_batch_size=100

__C.model.the_test_data='./data/t10k-images-idx3-ubyte.gz'
__C.model.the_test_label='./data/t10k-labels-idx1-ubyte.gz'
__C.model.the_test_batch_size=20

datalist

这里是我准备数据集的地方,一开始学的时候没看懂,看了好久突然明白了 iter 和 next的关系。
这里主要负责在train的时候输出我规定好的每个batch数据集以及他们的标签

import warnings
warnings.filterwarnings('ignore')
from yolov3_tf.config import cfg
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
import tensorflow as tf


class Dataset(object):

    def __init__(self,data_type):
        # self.data_path=cfg.model.the_train_data if data_type=="train" else cfg.model.the_test_data
        self.batch_size=cfg.model.the_train_batch_size if data_type=="train" else cfg.model.the_test_batch_size
        self.batch_count=0
        self.load_data=input_data.read_data_sets('data',one_hot=True)
        self.num_samples=self.load_data.train.num_examples if data_type=="train" else self.load_data.test.num_examples
        self.batch_num=int(np.ceil(self.num_samples//self.batch_size))

    def __iter__(self):
        return self

    def __next__(self):

        with tf.device('/cpu:0'):

            batch_images=np.zeros((self.batch_size,784))
            batch_labels=np.zeros((self.batch_size,10))

            num=0
            if self.batch_count<self.batch_num:
                while num<self.batch_size:
                    index=self.batch_count*self.batch_size+num
                    image=self.load_data.train.images[index]
                    label=self.load_data.train.labels[index]
                    batch_images[num,:]=np.array(image)
                    batch_labels[num,:]=label
                    num+=1
                self.batch_count+=1
                #按批次传输 images 和 labels
                return batch_images,batch_labels
            else:
                self.batch_count=0
                raise StopIteration

    def __len__(self):
        return self.batch_num

model

这里是算法的重头戏之一,建立模型。作为这个系列的第一集,算法也不去像github上的大佬们那样写的很复杂的模型,因为我们跑的是mnist数据集,最基础的一个数据集,而且我在common中也就写了一个weight_and_bias的函数,所以这次的模型就弄得很简单

import tensorflow as tf
from yolov3_tf.common import weight_and_bias

class model(object):
    def __init__(self,the_input_data,trainable):
        self.trainable=trainable
        self.model_input_data=the_input_data


        try:
            self.predict=self.__build_network()
        except:
            raise NotImplementedError("Can not build up the network")

    def __build_network(self):
        predict=weight_and_bias(self.model_input_data,trainable=self.trainable,name="weight_and_bias")
        return predict
    def compute_loss(self,input_label):
        loss=tf.reduce_mean(-tf.reduce_sum(input_label * tf.log(self.predict), axis=1), axis=0)
        return loss

train

这使我们训练的重头戏,为训练模型,以及保存模型,这里我就直接上代码,主要还是根据test的performance来保存模型

import warnings
warnings.filterwarnings('ignore')
import tensorflow as tf
from yolov3_tf.model import model
import os
import shutil
from tqdm import tqdm
from yolov3_tf.datalist import Dataset
import numpy as np


class Train(object):
    def __init__(self):


        self.total_epoches=200
        self.trainset=Dataset("train")
        self.testset=Dataset("test")
        # 如果你指定的设备不存在,允许TF自动分配设备
        self.sess=tf.Session(config=tf.ConfigProto(allow_soft_placement=True))


        with tf.name_scope("define_input"):
            self.input_data  = tf.placeholder(dtype=tf.float32,name="input_data")
            self.input_label = tf.placeholder(dtype=tf.float32,name="input_label")
            self.trainable   = tf.placeholder(dtype=tf.bool,name="training")

        with tf.name_scope("define_loss"):
            self.model=model(self.input_data,self.trainable)
            self.net_var=tf.global_variables()
            self.loss=self.model.compute_loss(self.input_label)

        with tf.name_scope("learn_rate"):
            self.global_step=tf.Variable(1.0,dtype=tf.float64,trainable=False,name="global_step")
            self.learn_rate=0.00001
            gloabel_step_update=tf.assign_add(self.global_step,1.0)

        with tf.name_scope("define_train"):
            train_stage_trainable_var_list=tf.trainable_variables()
            train_stage_optimizer=tf.train.AdamOptimizer(self.learn_rate).minimize(self.loss,var_list=train_stage_trainable_var_list)

            with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
                with tf.control_dependencies([train_stage_optimizer,gloabel_step_update]):
                    self.train_stage_op_with_all_variables=tf.no_op()

        with tf.name_scope("loader_and_saver"):
            self.loader=tf.train.Saver(self.net_var)
            self.saver=tf.train.Saver(tf.global_variables(),max_to_keep=10)

        with tf.name_scope("summary"):
            tf.summary.scalar("learn_rate",self.learn_rate)
            tf.summary.scalar("loss",self.loss)

            logdir="./data/log/"

            if os.path.exists(logdir):
                shutil.rmtree(logdir)
            os.mkdir(logdir)

            self.write_op=tf.summary.merge_all()
            self.summary_writer=tf.summary.FileWriter(logdir,graph=self.sess.graph)

    def train(self):
        print("=> Now it starts to train the model from stratch ...")
        self.sess.run(tf.global_variables_initializer())

        for epoch in range(1,1+self.total_epoches):
            train_op=self.train_stage_op_with_all_variables

            pbar=tqdm(self.trainset)
            train_epoch_loss,test_epoch_loss=[],[]

            for train_data in pbar:
                _,summary,train_step_loss,global_step_val=self.sess.run(
                    [train_op,self.write_op,self.loss,self.global_step],
                    feed_dict={
                        self.trainable: True,
                        self.input_data: train_data[0],
                        self.input_label: train_data[1],
                    }
                )

            train_epoch_loss.append(train_step_loss)
            self.summary_writer.add_summary(summary,global_step_val)
            pbar.set_description("train loss:%.2f" %train_step_loss)

            for test_data in self.testset:
                test_step_loss=self.sess.run(
                    self.loss,
                    feed_dict={
                        self.input_data:test_data[0],
                        self.input_label:test_data[1],
                        self.trainable:False,
                    }
                )
                test_epoch_loss.append(test_step_loss)

            train_epoch_loss,test_epoch_loss=np.mean(train_epoch_loss),np.mean(test_epoch_loss)
            ckpt_file="./checkpoint/model_test_loss=%.4f.ckpt" %test_epoch_loss
            print("=> Epoch: %2d Train loss: %.2f Test loss: %.2f Saving %s "
                  %(epoch,train_epoch_loss,test_epoch_loss,ckpt_file))
            self.saver.save(self.sess,ckpt_file,global_step=epoch)


if __name__=='__main__':
    print(Train().train())

以上的代码风格,我是在看一篇yolov3-tf版本中看到的,感觉他的代码写的极度优美就忍不住抄下来反复的看。
学习算法本身就是一件很难的事,而github上风格千千万万,如果能有一个好的风格模板并持之以恒相信会对学习有很大的帮助
[1]: https://github.com/YunYang1994/tensorflow-yolov3

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值