池化层的作用是为了缩减参数。nn.MaxPool2d的使用比较简单,其中参数ceil_model为False表示当输入的一部分小于池化核时则舍弃。
# 池化的作用是为了缩减参数
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
class My_nn(nn.Module):
def __init__(self):
super(My_nn, self).__init__()
self.maxpool = nn.MaxPool2d(kernel_size=3, ceil_mode=False) # ceil_model为False表示当输入的一部分小于池化核时舍弃
def forward(self, input):
output = self.maxpool(input)
return output
if __name__ == '__main__':
# 采用这个池化神经网络对CIFAR10的图片进行处理
# 为了减少数据集个数,对测试集进行处理
data_set = torchvision.datasets.CIFAR10('dataset', train=False, transform=torchvision.transforms.ToTensor(),
download=True)
data_loader = DataLoader(data_set, batch_size=64, drop_last=True)
# 初始化类
my_nn = My_nn()
writer = SummaryWriter('nn_maxpool')
step = 0
for data in data_loader:
imgs, targets = data
writer.add_images('input', imgs, step)
out_imgs = my_nn(imgs)
writer.add_images('output', out_imgs, step)
step += 1
writer.close()