在 torch.nn 的 Containers 中,可以用 Sequential 将多个 nn 模块按顺序组合成一个整体:
import torch
from torch import nn
from torch.nn import Sequential, Conv2d, MaxPool2d, Linear, Flatten
class Tudui(nn.Module):
    """Small convolutional classifier assembled with ``Sequential``.

    The network is three ``Conv2d(kernel_size=5, padding=2) -> MaxPool2d(2)``
    stages followed by ``Flatten`` and two ``Linear`` layers. The first
    ``Linear`` layer takes 1024 input features, which matches a 32x32 input:
    each pooling stage halves the spatial size (32 -> 16 -> 8 -> 4) and the
    last convolution emits 64 channels, giving 64 * 4 * 4 = 1024.

    Args:
        in_channels: Number of channels in the input images. Defaults to 3.
        num_classes: Size of the output vector produced for each sample.
            Defaults to 10.
    """

    def __init__(self, in_channels: int = 3, num_classes: int = 10) -> None:
        super().__init__()
        # padding is derived from the Conv2d output-size formula in the
        # official docs. With stride=1, dilation=1 and a 5x5 kernel:
        #   H_out = H_in + 2*padding - dilation*(5 - 1) - 1 + 1
        # so padding=2 keeps H and W unchanged through every convolution.
        self.module1 = Sequential(
            Conv2d(in_channels, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 64, 5, padding=2),
            MaxPool2d(2),
            Flatten(),
            # 1024 = 64 channels * 4 * 4; only valid for 32x32 inputs.
            Linear(1024, 64),
            Linear(64, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the sequential stack and return its output.

        Args:
            x: Input batch; the layer sizes above expect
                ``(N, in_channels, 32, 32)``.

        Returns:
            Tensor of shape ``(N, num_classes)``.
        """
        return self.module1(x)
# Smoke test: feed an all-ones batch through the network and print the
# shape of the result.
tudui = Tudui()
# (N, C, H, W) = (64, 3, 32, 32), every element set to 1.
# NOTE: `input` shadows the builtin of the same name; the name is kept as-is.
input = torch.ones(64, 3, 32, 32)
output = tudui(input)
print(output.shape)