TensorFlow神经网络自定义方法、存储、读取

import tensorflow as tf
from tensorflow.keras import datasets,layers,Sequential,optimizers
from tensorflow import keras
import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'

def preprocess(x,y):
    # print(x.shape)    #(32, 32, 3)
    # x = tf.reshape(x, [-1, 32 * 32 * 3])
    # print(x.shape)    #(1, 3072)
    # y = tf.squeeze(y)
    # y = tf.one_hot(y , depth=10 )
    x = tf.cast(x , dtype=tf.float32) / 255.
    y = tf.cast(y , dtype=tf.int32)
    return x,y

class MyLayer(layers.Layer):
    def __init__(self , inp_dim , outp_dim):
        super(MyLayer, self).__init__()
        self.kernel = self.add_variable('w' , [inp_dim,outp_dim])
        # self.bias = self.add_variable('b' , [outp_dim])
    def call(self, inputs, training = None):
        x = inputs@self.kernel
        return x
class MyNetwork(tf.keras.Model): #注意别继承错
    def __init__(self):
        super(MyNetwork, self).__init__()
        self.fc1 = MyLayer(32*32*3,256)
        self.fc2 = MyLayer(256,256)
        self.fc3 = MyLayer(256,64)
        self.fc4 = MyLayer(64,32)
        self.fc5 = MyLayer(32,10)
    def call(self, inputs, training=None, mask=None):
        x = self.fc1(inputs)
        x = tf.nn.relu(x)
        x = self.fc2(x)
        x = tf.nn.relu(x)
        x = self.fc3(x)
        x = tf.nn.relu(x)
        x = self.fc4(x)
        x = tf.nn.relu(x)
        logits = self.fc5(x)
        return logits

batchzs = 128
def main():
    (x,y),(x_test,y_test) =datasets.cifar10.load_data()
    print(x.shape,y.shape, x.max() , x.min() ,y.max() , y.min())  #(50000, 32, 32, 3) (50000, 1)
    x = tf.reshape(x , (-1,32*32*3))
    x_test = tf.reshape(x_test , (-1,32*32*3))
    y = tf.one_hot(tf.squeeze(y) ,depth=10)
    y_test = tf.one_hot(tf.squeeze(y_test),depth=10)
    print(x.shape , y.shape )
    db = tf.data.Dataset.from_tensor_slices((x,y))
    db = db.map(preprocess).shuffle(500000).batch(batch_size=batchzs)
    db_test = tf.data.Dataset.from_tensor_slices((x_test,y_test))
    db_test = db_test.map(preprocess).batch(batch_size=batchzs)

    samp = next(iter(db))
    print( samp[0].shape , samp[1].shape )

    network = MyNetwork()
    network.compile(optimizer=optimizers.Adam(lr=1e-3) , loss=tf.losses.CategoricalCrossentropy(from_logits=True),metrics=['accuracy'])
    network.fit(db , epochs=15 , verbose=2, validation_data=db_test,validation_freq=1)

    network.save_weights('./mydense0317001.ckpt')
    print('save weights.')
    del network
    network2 = MyNetwork()
    network2.compile(optimizer=optimizers.Adam(lr=1e-3) , loss=tf.losses.CategoricalCrossentropy(from_logits=True),metrics=['accuracy'])
    network2.load_weights('./mydense0317001.ckpt')
    print('load weights from file.')
    network2.evaluate(db_test)

if __name__ == '__main__':
    main()

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

weixin_39540983

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值