实验平台：Google Colab。
网络结构：该分类网络首先使用三层 3D 卷积，然后使用一层 2D 卷积，最后接三层全连接层。
数据集：Indian Pines（16 个类别）。
评价指标：loss、precision、recall、f1-score，以及 support（各类别的样本数）。
网络模型：
class HybridSN(nn.Module):
    """Hybrid 3D/2D CNN classifier: three 3D convs, one 2D conv, three FC layers.

    Input is expected as ``(batch, 1, bands, patch_size, patch_size)``. The
    default arguments (30 bands, 25x25 patches, 16 classes) reproduce the
    original hard-coded layer sizes: 576 input channels for ``conv2d1`` and
    18496 input features for ``fc1``.

    Args:
        bands: number of spectral bands (depth of the 3D input). Must be > 12.
        patch_size: height/width of the square spatial patch. Must be > 8.
        num_classes: number of output logits.

    Raises:
        ValueError: if ``bands`` or ``patch_size`` is too small for the
            unpadded convolutions.
    """

    def __init__(self, bands=30, patch_size=25, num_classes=16):
        super().__init__()
        # Unpadded stride-1 convs shrink each axis by (kernel - 1).
        # The depth kernels are 7, 5 and 3, so depth shrinks by 6 + 4 + 2 = 12.
        depth = bands - 12
        # Three 3x3 spatial kernels in the 3D stage remove 2 pixels each.
        spatial_3d = patch_size - 6
        # The 3x3 2D conv removes another 2 pixels.
        spatial_2d = spatial_3d - 2
        if depth < 1 or spatial_2d < 1:
            raise ValueError(
                f"HybridSN needs bands > 12 and patch_size > 8, "
                f"got bands={bands}, patch_size={patch_size}"
            )

        self.conv3d1 = nn.Conv3d(1, 8, kernel_size=(7, 3, 3), stride=1, padding=0)
        self.bn1 = nn.BatchNorm3d(8)
        self.conv3d2 = nn.Conv3d(8, 16, kernel_size=(5, 3, 3), stride=1, padding=0)
        self.bn2 = nn.BatchNorm3d(16)
        self.conv3d3 = nn.Conv3d(16, 32, kernel_size=(3, 3, 3), stride=1, padding=0)
        self.bn3 = nn.BatchNorm3d(32)
        # The 3D output (32 channels x depth) is folded into 2D channels in forward().
        self.conv2d1 = nn.Conv2d(32 * depth, 64, kernel_size=(3, 3), stride=1, padding=0)
        self.bn4 = nn.BatchNorm2d(64)
        self.fc1 = nn.Linear(64 * spatial_2d * spatial_2d, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, num_classes)
        self.dropout = nn.Dropout(0.4)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for input ``x``."""
        out = F.relu(self.bn1(self.conv3d1(x)))
        out = F.relu(self.bn2(self.conv3d2(out)))
        out = F.relu(self.bn3(self.conv3d3(out)))
        # (N, C, D, H, W) -> (N, C*D, H, W): merge the channel and depth axes.
        # The original hard-coded the 19x19 spatial size; this works for any patch size.
        out = out.flatten(1, 2)
        out = F.relu(self.bn4(self.conv2d1(out)))
        out = out.flatten(1)
        # ReLU then dropout is the conventional order. It equals
        # relu(dropout(x)) exactly, because the dropout mask/scale is non-negative.
        out = self.dropout(F.relu(self.fc1(out)))
        out = self.dropout(F.relu(self.fc2(out)))
        return self.fc3(out)
# Smoke test: push one random input patch through the network to check
# that all layer shapes line up.
x = torch.randn(1, 1, 30, 25, 25)
net = HybridSN()
y = net(x)
print(y.shape)
# Fail loudly instead of relying on someone reading the printed shape.
if y.shape != (1, 16):
    raise RuntimeError(f"unexpected HybridSN output shape: {tuple(y.shape)}")