import tensorflow as tf
from tensorflow.keras import datasets,layers,Sequential,optimizers
from tensorflow import keras
import os
# Hide TensorFlow C++ log messages below level 2 (i.e. INFO and WARNING).
# NOTE(review): this runs after `import tensorflow` above, so messages TF emits
# while importing may not be suppressed — consider setting it before the import.
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
def preprocess(x, y):
    """Convert one (features, label) example into model-ready tensors.

    Args:
        x: Image features; cast to float32 and divided by 255 so 8-bit pixel
            values land in [0, 1].
        y: Label tensor; cast to int32. In main() the labels reaching this
            function are already one-hot vectors, so the result is an int32
            one-hot tensor.

    Returns:
        Tuple ``(features, label)`` of the converted tensors.
    """
    features = tf.cast(x, dtype=tf.float32) / 255.
    label = tf.cast(y, dtype=tf.int32)
    return features, label
class MyLayer(layers.Layer):
    """Fully connected layer computing ``inputs @ kernel`` (plus optional bias).

    No activation is applied; callers add their own non-linearity.

    Args:
        inp_dim: Size of the last axis of the input.
        outp_dim: Number of output features.
        use_bias: If True, add a learned, zero-initialized bias vector of
            length ``outp_dim``. Defaults to False (a pure matmul), which is
            the layer's original behavior.
        **kwargs: Forwarded to ``keras.layers.Layer`` (e.g. ``name``, ``dtype``).
    """

    def __init__(self, inp_dim, outp_dim, use_bias=False, **kwargs):
        super().__init__(**kwargs)
        # `add_variable` is a deprecated alias of `add_weight` in tf.keras and
        # takes a different positional signature in Keras 3; passing keyword
        # arguments to `add_weight` works in both. The variable name 'w' and
        # the default initializer are unchanged.
        self.kernel = self.add_weight(name='w', shape=[inp_dim, outp_dim])
        if use_bias:
            self.bias = self.add_weight(name='b', shape=[outp_dim],
                                        initializer='zeros')
        else:
            self.bias = None

    def call(self, inputs, training=None):
        x = inputs @ self.kernel
        if self.bias is not None:
            x = x + self.bias
        return x
class MyNetwork(tf.keras.Model):  # note: be sure to inherit from the right base class (keras.Model)
    """Five-layer fully connected classifier built from MyLayer blocks.

    Widths: 3072 (32*32*3) -> 256 -> 256 -> 64 -> 32 -> 10. A ReLU follows
    every layer except the last, whose output is returned as raw logits.
    """

    def __init__(self):
        super().__init__()
        # Keep the attribute names fc1..fc5: they form the object paths under
        # which save_weights (TF checkpoint format) stores the variables.
        self.fc1 = MyLayer(32 * 32 * 3, 256)
        self.fc2 = MyLayer(256, 256)
        self.fc3 = MyLayer(256, 64)
        self.fc4 = MyLayer(64, 32)
        self.fc5 = MyLayer(32, 10)

    def call(self, inputs, training=None, mask=None):
        hidden = inputs
        for dense in (self.fc1, self.fc2, self.fc3, self.fc4):
            hidden = tf.nn.relu(dense(hidden))
        # No activation on the output layer: the loss consumes logits.
        return self.fc5(hidden)
# Mini-batch size shared by the training and test input pipelines in main().
batchzs = 128
def _make_datasets():
    """Load CIFAR-10 and return batched ``(train, test)`` tf.data pipelines.

    Images are flattened to 32*32*3 = 3072 features and labels are one-hot
    encoded over 10 classes before ``preprocess`` is mapped over each example.
    """
    (x, y), (x_test, y_test) = datasets.cifar10.load_data()
    print(x.shape, y.shape, x.max(), x.min(), y.max(), y.min())  # (50000, 32, 32, 3) (50000, 1)
    num_train = x.shape[0]

    x = tf.reshape(x, (-1, 32 * 32 * 3))
    x_test = tf.reshape(x_test, (-1, 32 * 32 * 3))
    # Labels come as (N, 1); squeeze only that axis so N == 1 still works.
    y = tf.one_hot(tf.squeeze(y, axis=1), depth=10)
    y_test = tf.one_hot(tf.squeeze(y_test, axis=1), depth=10)
    print(x.shape, y.shape)

    db = tf.data.Dataset.from_tensor_slices((x, y))
    # A buffer equal to the number of training examples already yields a
    # full shuffle; a larger buffer adds nothing.
    db = db.map(preprocess).shuffle(num_train).batch(batch_size=batchzs)
    db_test = tf.data.Dataset.from_tensor_slices((x_test, y_test))
    db_test = db_test.map(preprocess).batch(batch_size=batchzs)
    return db, db_test


def _build_model():
    """Create a fresh MyNetwork compiled with Adam and categorical cross-entropy."""
    model = MyNetwork()
    model.compile(
        # `lr` is a deprecated alias that newer Keras releases reject;
        # `learning_rate` is the supported keyword.
        optimizer=optimizers.Adam(learning_rate=1e-3),
        # MyNetwork returns raw logits, so the loss applies the softmax itself.
        loss=tf.losses.CategoricalCrossentropy(from_logits=True),
        metrics=['accuracy'],
    )
    return model


def main():
    """Train MyNetwork on CIFAR-10, save its weights, then reload and evaluate them."""
    db, db_test = _make_datasets()
    samp = next(iter(db))
    print(samp[0].shape, samp[1].shape)

    ckpt_path = './mydense0317001.ckpt'
    network = _build_model()
    network.fit(db, epochs=15, verbose=2, validation_data=db_test, validation_freq=1)
    # NOTE(review): under tf.keras (Keras 2) a `.ckpt` path selects the
    # TensorFlow checkpoint format; Keras 3 only accepts `.weights.h5` names.
    network.save_weights(ckpt_path)
    print('save weights.')
    del network

    # A brand-new instance shows the weights round-trip through the file.
    network2 = _build_model()
    network2.load_weights(ckpt_path)
    print('load weights from file.')
    network2.evaluate(db_test)
# Run training and evaluation only when executed as a script, not on import.
if __name__ == '__main__':
    main()