今天在实现自定义全连接层时遇到了这个问题:输入数据的类型与权重参数矩阵的类型不匹配。
原因是 MNIST 手写数字图像是像素值为 0~255 的 uint8 灰度图(并不是二值图像),而权重矩阵是 float32;加上 x_train = x_train/255.0,把输入转换为浮点数并归一化到 [0, 1] 之后,就好了。
import keras
from tensorflow.keras.datasets import mnist
from tensorflow.keras import Model, layers
import tensorflow as tf
import numpy as np
# Load MNIST: grayscale digit images stored as uint8 (0-255) with integer class labels 0-9.
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Add a trailing channel axis, then convert to float32 and scale the pixels to [0, 1] so the
# inputs match the float32 weights of the model.  The test split gets exactly the same
# treatment as the training split; otherwise evaluation would run on raw 0-255 integers.
x_train = np.expand_dims(x_train, axis=-1).astype("float32") / 255.0
x_test = np.expand_dims(x_test, axis=-1).astype("float32") / 255.0

# One-hot encode the labels (10 classes) for use with a categorical cross-entropy loss.
y_train = tf.keras.utils.to_categorical(y_train, 10)
y_test = tf.keras.utils.to_categorical(y_test, 10)
class My_Dnn_Layers(layers.Layer):
    """Bias-free fully connected layer that computes ``inputs @ w``.

    Args:
        out: Number of output units (columns of the weight matrix).
        **kwargs: Standard Keras ``Layer`` keyword arguments (e.g. ``name``),
            forwarded unchanged to the base class.
    """

    def __init__(self, out, **kwargs):
        super().__init__(**kwargs)
        self.out = out
        # The weight matrix is created lazily in build(), once the input width is known.
        self.w = None

    def build(self, input_shape):
        """Create the ``(input_shape[-1], out)`` float32 weight matrix."""
        self.w = self.add_weight(
            name="w",
            shape=(int(input_shape[-1]), self.out),
            dtype="float32",
            # Keras' default for floating-point weights, spelled out for clarity.
            initializer="glorot_uniform",
            trainable=True,
        )

    def call(self, inputs, *args, **kwargs):
        """Return the matrix product of ``inputs`` and the weight matrix."""
        # Integer inputs (e.g. raw uint8 pixels) are not auto-cast by Keras and would make
        # tf.matmul fail with a dtype mismatch, so cast explicitly to the compute dtype.
        inputs = tf.cast(inputs, self.compute_dtype)
        return tf.matmul(inputs, self.w)

    def get_config(self):
        """Return the constructor arguments so the layer can be serialized and re-created."""
        config = super().get_config()
        config.update({"out": self.out})
        return config
class MyDNN(Model):
    """Minimal classifier: flatten the input, apply one custom dense layer, then softmax.

    Subclasses the ``Model`` imported from ``tensorflow.keras`` so that the model and
    its layers all come from the same Keras implementation.
    """

    def __init__(self):
        super().__init__()
        # NOTE: the attribute spelling "Fatten" is kept as-is so existing checkpoints and
        # any external references to it keep working.
        # No input_shape argument: it has no function in a subclassed model, and the
        # previous value (28, 28) did not match the (28, 28, 1) samples fed to the model.
        self.Fatten = layers.Flatten()
        self.dense = My_Dnn_Layers(10)

    def call(self, inputs, training=None, mask=None):
        """Return per-class probabilities for a batch of inputs.

        ``training`` and ``mask`` are accepted for Keras API compatibility but unused.
        """
        x = self.Fatten(inputs)
        x = self.dense(x)
        return tf.nn.softmax(x)
# Build the classifier, configure the optimizer/loss/metric, and train for 10 epochs.
net = MyDNN()
net.compile(
    loss='categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)
net.fit(x=x_train, y=y_train, epochs=10)