前言
我在github搜了一个mnist分类代码,基于TensorFlow框架的,效果挺不错,但是我只想用pytorch,所以打算转成pytorch代码。
过程
TensorFlow网络:
def model_net(num):
model = Sequential()
model.add(Conv2D(filters=16, kernel_size=(5, 5), padding='Same',
activation='relu', input_shape=(28, 28, 3)))
model.add(BatchNormalization())
model.add(Conv2D(filters=16, kernel_size=(3, 3), padding='Same',
activation='relu'))
model.add(BatchNormalization())
model.add(Conv2D(filters=32, kernel_size=(3, 3), padding='Same',
activation='relu'))
model.add(BatchNormalization())
model.add(MaxPool2D(pool_size=(2, 2), strides=(2, 2)))
model.add(Flatten())
model.add(Dense(256, activation="relu"))
model.add(Dropout(0.5))

文章介绍了如何将一个基于TensorFlow的MNIST手写数字分类模型转换为PyTorch实现。作者首先展示了TensorFlow模型的结构,然后逐步构建了等效的PyTorch网络,并使用torchsummary库来输出PyTorch模型的结构。虽然参数略有差异,但两个模型的运行结果相似。
最低0.47元/天 解锁文章
963

被折叠的 条评论
为什么被折叠?



