conv1d
Input shape: (batch_size, channels, length)
nn.Conv1d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1)
nn.Conv1d(1, 20, 5): 1 input channel, 20 convolution kernels (i.e. 20 output channels), kernel size 5.
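Here is a minimal sketch of the shapes involved, assuming PyTorch is installed. The batch size 150 and length 103 are illustrative values borrowed from the hyperspectral example later in these notes. With stride 1, no padding and no dilation, each output position covers 5 input positions, so the length shrinks by kernel_size - 1.

```python
import torch
import torch.nn as nn

# Conv1d expects input of shape (batch_size, channels, length).
conv = nn.Conv1d(1, 20, 5)  # 1 input channel, 20 kernels, kernel size 5

# Illustrative input: 150 samples, 1 channel, length 103 (sizes are assumptions).
x = torch.randn(150, 1, 103)
y = conv(x)

# stride=1, padding=0, dilation=1: L_out = L_in - kernel_size + 1 = 103 - 5 + 1 = 99
print(y.shape)            # torch.Size([150, 20, 99])

# The learnable kernels are stored as (out_channels, in_channels, kernel_size).
print(conv.weight.shape)  # torch.Size([20, 1, 5])
```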
conv2d
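Input shape: (batch_size, channels, height, width)
nn.Conv2d(in_channels, out_channels, kernel_size, stride=(1, 1), padding=0, dilation=(1, 1), groups=1)
nn.Conv2d(1, 20, (3, 3), stride=(1, 1), padding=(2, 2)): 1 input channel, 20 kernels of size 3×3, with 2 pixels of zero padding on every side.
For conv2d the input is 4-D, e.g. [150, 103, 7, 7]: batch_size 150, 103 spectral bands used as channels, 7×7 patch. With this input, in_channels must be 103.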
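The sketch below, assuming PyTorch and using random tensors in place of real data, does two things. First, it runs the 1-channel layer described above on single-channel 7×7 patches. Second, it runs a layer with in_channels=103 on [150, 103, 7, 7] patches. The 80 output channels mirror the Conv2d in the summary, but any number would do.

```python
import torch
import torch.nn as nn

# Layer from the text: 1 input channel, 20 kernels of size 3x3, padding 2.
conv_1ch = nn.Conv2d(1, 20, (3, 3), stride=(1, 1), padding=(2, 2))
x1 = torch.randn(150, 1, 7, 7)
out_1ch = conv_1ch(x1)
print(out_1ch.shape)    # torch.Size([150, 20, 9, 9]): 7 + 2*2 - 3 + 1 = 9

# Hyperspectral patches [150, 103, 7, 7]: the 103 bands are the channels.
patches = torch.randn(150, 103, 7, 7)
conv_bands = nn.Conv2d(103, 80, (3, 3))
out_bands = conv_bands(patches)
print(out_bands.shape)  # torch.Size([150, 80, 5, 5]): 7 - 3 + 1 = 5

# The 1-channel layer rejects 103-channel input with a RuntimeError.
try:
    conv_1ch(patches)
except RuntimeError as err:
    print("channel mismatch:", err)
```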
conv3d
Input shape: (batch_size, channels, depth, height, width)
nn.Conv3d(in_channels, out_channels, kernel_size, stride=(1, 1, 1), padding=0, dilation=(1, 1, 1), groups=1)
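nn.Conv3d(1, 90, (24, 3, 3), stride=(9, 1, 1), padding=(1, 1, 1)): 1 input channel, 90 kernels of size 24×3×3 (24 along the band axis), stride 9 along the band axis.
The input is 5-D, e.g. [150, 1, 103, 7, 7]: 150 is batch_size, 1 is the added dimension (the single input channel), 103 is the number of spectral bands (the depth), and 7×7 is the patch size.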
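As a worked check, the output size along each axis follows the standard PyTorch convolution formula. This calculation assumes the default dilation d = 1 and uses the input (150, 1, 103, 7, 7) described above with the layer nn.Conv3d(1, 90, (24, 3, 3), stride=(9, 1, 1), padding=(1, 1, 1)):

$$
\begin{aligned}
D_{out} &= \left\lfloor \frac{D_{in} + 2p_D - d_D(k_D - 1) - 1}{s_D} \right\rfloor + 1
        = \left\lfloor \frac{103 + 2 - 23 - 1}{9} \right\rfloor + 1 = 9 + 1 = 10 \\
H_{out} &= \left\lfloor \frac{7 + 2 - 2 - 1}{1} \right\rfloor + 1 = 7, \qquad W_{out} = 7
\end{aligned}
$$

So the layer produces an output of shape [150, 90, 10, 7, 7].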
FC
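nn.Linear(in_features, out_features)
The input is expected as (batch_size, in_features). More generally, nn.Linear acts on the last dimension, so any shape (*, in_features) works.
Here 150 is batch_size and 103 is the number of bands. There is no patch: each single pixel is one input.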
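Here is a minimal pixel-wise classifier sketch, assuming PyTorch. Only the input width of 103 bands comes from the text. The hidden size 64 and the 9 output classes are hypothetical choices.

```python
import torch
import torch.nn as nn

# Pixel-wise input: 150 pixels, each a vector of 103 band values (no spatial patch).
pixels = torch.randn(150, 103)

# in_features must equal the last input dimension (103).
# Hidden size 64 and 9 classes are illustrative assumptions.
mlp = nn.Sequential(
    nn.Linear(103, 64),
    nn.ReLU(),
    nn.Linear(64, 9),
)
logits = mlp(pixels)
print(logits.shape)  # torch.Size([150, 9])
```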
Summary
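In my local experiments, when patch > 1, data = data.unsqueeze(0) adds a leading dimension. This gives 4-D data (Planes x Channels x Width x Height), with Planes = 1 and Channels = the number of bands. After the DataLoader, the data become 5-D: ((Batch x) Planes x Channels x Width x Height). That is why nn.Conv3d(1, 90, (24, 3, 3)) has 1 input channel.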
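The pipeline above can be sketched as follows, assuming PyTorch. PatchDataset is a hypothetical dataset filled with random values, standing in for the real patch extraction. Only the unsqueeze(0) in __getitem__ and the Conv3d hyperparameters come from the text. The batch size 150 and the 9 label classes are assumptions.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset


class PatchDataset(Dataset):
    """Hypothetical dataset of random (bands x width x height) patches."""

    def __init__(self, n_samples=300, bands=103, patch_size=7):
        self.data = torch.randn(n_samples, bands, patch_size, patch_size)
        self.labels = torch.randint(0, 9, (n_samples,))  # 9 classes: assumption

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        data = self.data[i]       # (Channels, W, H) = (103, 7, 7)
        data = data.unsqueeze(0)  # (Planes, Channels, W, H) = (1, 103, 7, 7)
        return data, self.labels[i]


loader = DataLoader(PatchDataset(), batch_size=150, shuffle=True)
x, y = next(iter(loader))
print(x.shape)  # torch.Size([150, 1, 103, 7, 7]): DataLoader adds the batch axis

# Planes = 1 serves as the single input channel of Conv3d.
conv3d = nn.Conv3d(1, 90, (24, 3, 3), stride=(9, 1, 1), padding=(1, 1, 1))
out = conv3d(x)
print(out.shape)  # torch.Size([150, 90, 10, 7, 7])
```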
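When the network uses conv2d, it still receives the 5-D data ((Batch x) Planes x Channels x Width x Height) with Planes = 1. Remove the Planes dimension with x = x.squeeze(1), which gives ((Batch x) Channels x Width x Height). Then apply nn.Conv2d(input_channels, 80, (3, 3)), where input_channels is the number of bands. A bare x.squeeze() also removes Planes, but it drops every size-1 dimension, including the batch dimension when a batch holds only one sample.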
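Here is a short sketch of the conv2d path, assuming PyTorch and random data. It removes only the Planes axis and compares squeeze(1) with a bare squeeze() on a batch of size 1.

```python
import torch
import torch.nn as nn

# A 5-D batch as produced by the DataLoader: (Batch, Planes=1, Channels=103, W, H).
x = torch.randn(150, 1, 103, 7, 7)

# squeeze(1) removes only the Planes axis.
x2d = x.squeeze(1)
print(x2d.shape)  # torch.Size([150, 103, 7, 7])

conv2d = nn.Conv2d(103, 80, (3, 3))  # in_channels = number of bands
out = conv2d(x2d)
print(out.shape)  # torch.Size([150, 80, 5, 5])

# Edge case: a batch containing a single sample.
single = torch.randn(1, 1, 103, 7, 7)
print(single.squeeze().shape)   # torch.Size([103, 7, 7]): batch axis lost
print(single.squeeze(1).shape)  # torch.Size([1, 103, 7, 7]): batch axis kept
```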