代码如下
from torchvision import datasets
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear,Sequential
import torch
# from tensorboardX import SummaryWriter
# Build the neural network.
class Easy_Cnn(nn.Module):
    """Small three-stage CNN producing 10 output values per sample.

    Expects batched input of shape (N, 3, 32, 32). Each stage is a 5x5
    convolution (stride 1, padding 2, which keeps the spatial size) followed
    by 2x2 max pooling (which halves it): 32 -> 16 -> 8 -> 4. The resulting
    64 x 4 x 4 = 1024 features are flattened and passed through two fully
    connected layers (1024 -> 64 -> 10), giving an output of shape (N, 10).

    NOTE(review): the first Linear layer hard-codes 1024 input features, so
    only 32x32 spatial inputs work. There is no activation between the layers,
    so the two Linear layers compose into a single affine map.
    """

    def __init__(self):
        # Run nn.Module's initializer before registering submodules. This
        # calls the parent constructor; it does not override anything.
        super().__init__()

        stages = []
        channels_in = 3
        # Three conv + pool stages; creation order matches the original so
        # parameter init and state_dict keys (model.0, model.2, ...) are unchanged.
        for channels_out in (32, 32, 64):
            stages.append(Conv2d(channels_in, channels_out, 5, stride=1, padding=2))
            stages.append(MaxPool2d(2))
            channels_in = channels_out

        # Classifier head: flatten (N, 64, 4, 4) to (N, 1024), then project to 10.
        stages.extend([
            Flatten(),
            Linear(1024, 64, bias=True),
            Linear(64, 10, bias=True),
        ])
        self.model = Sequential(*stages)

    def forward(self, x):
        """Apply the layer stack to ``x`` and return the (N, 10) result."""
        return self.model(x)
# Smoke test: push a dummy batch through the network and inspect the output.
# NOTE: no dataset is loaded and no training happens here.
# Named dummy_input (not `input`) to avoid shadowing the builtin input().
dummy_input = torch.ones(64, 3, 32, 32)  # 64 all-ones samples of shape (3, 32, 32)
easy_cnn = Easy_Cnn()
output = easy_cnn(dummy_input)
print(output.shape)  # expected: torch.Size([64, 10])
print(output)
# Optional: visualize the model graph in TensorBoard (requires tensorboardX and
# uncommenting its import at the top of the file).
# writer = SummaryWriter("../logs_seq")
# writer.add_graph(easy_cnn, dummy_input)
# writer.close()