tensorflow基础模板

import tensorflow as tf
import numpy as np
import os

#hyper parameters
batch_size = 32
save_path = 'model'
max_train_step = 10000
lr = 0.001

#network structure
class nn(object):
    def __init__(self, name='nn', trainning=True, reuse=False):
        with tf.variable_scope(name, reuse=reuse):
            self.global_step = tf.Variable(0, trainable=False, dtype=tf.int32)
    def summary(self):
        pass

#placehoder

#model
train_model = nn()

#opt
train_up = tf.train.AdamOptimizer(lr).minimize(train_model.loss, train_model.global_step)

#save
saver = tf.train.Saver()

with tf.Session() as sess:
    #recoder, summary
    train_model.summary()
    train_writer = tf.summary.FileWriter('log', graph=sess.graph)
    merged = tf.summary.merge_all()

    #restore or initail
    ckpt = tf.train.get_checkpoint_state(save_path)
    if ckpt:
        ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
        saver.restore(sess, os.path.join(save_path, ckpt_name))
    else:
        sess.run(tf.global_variables_initializer())

    #circulation start
    global_step_val = sess.run(train_model.global_step)
    while global_step_val < max_train_step:
        sess.run(train_up, feed_dict={})
        global_step_val += 1
        if global_step_val % 100 == 0:
            saver.save(sess, os.path.join(save_path, 'nn.ckpt'), global_step_val)
            merged_summary = sess.run(merged, feed_dict={})
            train_writer.add_summary(merged_summary, global_step_val)


  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值