import tensorflow as tf
"""
(1)构造函数__init__参数
input_sz: 输入层placeholder的4-D shape,如mnist是[None,28,28,1]
(2)train函数:训练一步
batch_input: 输入的batch
batch_output: label
learning_rate:学习率
返回:正确率和loss值(float) 格式:{"accuracy":accuracy,"loss":loss}
(3)forward:训练后用于测试
(4)save(save_path,steps)保存模型
(5)restore(path):从文件夹中读取最后一个模型
(6)loss函数使用cross-entrop one-hot版本:y*log(y_net)
(7)optimizer使用adamoptimier
"""
class ResNet:
    """TF1.x graph-mode ResNet wrapper.

    The attributes below are class-level placeholders for graph tensors/ops;
    presumably they are populated by a graph-building routine (e.g. __init__)
    that is not visible in this chunk — confirm against the full file.
    """
    sess=None  # tf.Session shared by train/forward/save
    #Tensors
    input=None        # input placeholder; 4-D, e.g. [None,28,28,1] per module docstring
    output=None       # network output tensor (predictions)
    desired_out=None  # label placeholder (one-hot, per module docstring)
    loss=None         # cross-entropy loss tensor (see module docstring, item 6)
    iscorrect=None    # per-example correctness tensor
    accuracy=None     # batch accuracy tensor
    optimizer=None    # training op (Adam, per module docstring item 7)
    param_num=0  # number of trainable parameters
    # Hyper-parameters
    learning_rate=None  # learning-rate placeholder, fed on every train step
    MOMENTUM = 0.9
    WEIGHT_DECAY = 1e-4  # L2 regularization strength
    ACTIVATE = None
    CONV_PADDING = "SAME"
    MAX_POOL_PADDING = "SAME"
    CONV_WEIGHT_INITAILIZER = tf.keras.initializers.he_normal()#tf.truncated_normal_initializer(stddev=0.1)
    CONV_BIAS_INITAILIZER = tf.constant_initializer(value=0.0)
    FC_WEIGHT_INITAILIZER = tf.keras.initializers.he_normal()#tf.truncated_normal_initializer(stddev=0.1)
    FC_BIAS_INITAILIZER = tf.constant_initializer(value=0.0)
def train(self,batch_input,batch_output,learning_rate):
_,accuracy,loss=self.sess.run([self.optimizer,self.accuracy,self.loss],
feed_dict={self.input:batch_input,self.desired_out:batch_output,self.learning_rate:learning_rate})
return {"accuracy":accuracy,"loss":loss}
def forward(self,batch_input):
return self.sess.run(self.output,feed_dict={self.input:batch_input})
def save(self,save_path,steps):
saver=tf.train.Saver(max_to_keep=5)
saver.save(self.sess,save_path,global_step=steps)
print("[*]save success")
def restore(self,restore_path):
path=tf.train.latest_checkpoint(restore_path)
if path==No
# --- scrape provenance (non-code residue from the source article) ---
# [Deep Learning] ResNet implementation details
# Latest recommended article published 2024-06-23 16:51:28
# (The article's abstract duplicated the module docstring above; the
# restore() implementation was truncated in the extracted source.)