## NiN (Network in Network)
import torch
from torch import nn
from d2l import torch as d2l
def nn_block(in_channel, out_channel, kernal, strides, padding):
    """Build one NiN (Network in Network) block.

    The block is a spatial convolution followed by two 1x1 convolutions,
    with a ReLU after each of the three convolutions.

    Args:
        in_channel: number of input channels of the first convolution.
        out_channel: number of output channels of every convolution in the
            block.
        kernal: kernel size of the first (spatial) convolution. The
            misspelling is kept because callers pass it by keyword.
        strides: stride of the first convolution.
        padding: padding of the first convolution.

    Returns:
        An ``nn.Sequential`` of six modules, in order:
        Conv2d, ReLU, Conv2d(1x1), ReLU, Conv2d(1x1), ReLU.
    """
    modules = [
        nn.Conv2d(in_channel, out_channel, kernal, strides, padding),
        nn.ReLU(),
    ]
    # The two trailing 1x1 convolutions keep the channel count unchanged.
    for _ in range(2):
        modules.append(nn.Conv2d(out_channel, out_channel, kernel_size=1))
        modules.append(nn.ReLU())
    return nn.Sequential(*modules)
# NiN model: NiN blocks separated by 3x3 / stride-2 max pooling. A final NiN
# block with 10 output channels is globally average-pooled to 1x1 and
# flattened, so each input yields 10 scores without any fully connected layer.
net = nn.Sequential(
    # Stage 1: single-channel input, 11x11 stride-4 convolution, 96 channels.
    nn_block(1, 96, kernal=11, strides=4, padding=0),
    nn.MaxPool2d(kernel_size=3, stride=2),
    # Stage 2: 5x5 convolution, 256 channels.
    nn_block(96, 256, kernal=5, strides=1, padding=2),
    nn.MaxPool2d(kernel_size=3, stride=2),
    # Stage 3: 3x3 convolution, 384 channels.
    nn_block(256, 384, kernal=3, strides=1, padding=1),
    nn.MaxPool2d(kernel_size=3, stride=2),
    nn.Dropout(p=0.5),
    # Output stage: 10 channels, matching the 10 Fashion-MNIST classes.
    nn_block(384, 10, kernal=3, strides=1, padding=1),
    nn.AdaptiveAvgPool2d((1, 1)),
    nn.Flatten(),
)
# Inspect the output shape of each block: feed a random single-channel
# 224x224 tensor through the top-level layers of `net` one at a time and
# print the shape after each layer.
x = torch.rand(size=(1, 1, 224, 224))
# Shape check only: no gradients are needed, so skip building an autograd
# graph.
with torch.no_grad():
    for layer in net:
        x = layer(x)
        # "\t" is a real tab separating the label from the shape.
        print(layer.__class__.__name__, "out shape:\t", x.shape)
# Training hyperparameters (pinyin names kept as-is):
#   xuexilv      - learning rate
#   xunlianxishu - number of training epochs
#   btach        - mini-batch size
xuexilv = 0.1
xunlianxishu = 10
btach = 128

# Fashion-MNIST loaders with images resized to 224x224, the input size the
# first NiN block's 11x11 / stride-4 convolution is laid out for.
train_iter, test_iter = d2l.load_data_fashion_mnist(btach, resize=224)

# Train with d2l's helper. d2l.try_gpu() presumably selects a GPU when one is
# available -- confirm against the installed d2l version.
d2l.train_ch6(
    net, train_iter, test_iter, xunlianxishu, xuexilv, d2l.try_gpu()
)
d2l.plt.show()