我们如何将多张图片的特征融合到一块,这里提供了一种处理通道为1的医学图像方法。
假如我们要将3张图像融合到一起。在输入的时候,我们可以将三张图片合成为3通道输入,然后在三个通道上,分别用神经网络进行卷积,然后将卷积的图像,在全连接层,进行concat,然后进全连接层输出分类数目。
import torch
import torch.nn as nn
import torchvision.models as models
class CustomResNet50(nn.Module):
def __init__(self, in_channels=1,num_classes=2, chunk=3):
super(CustomResNet50, self).__init__()
self.chunk = chunk
# 加载预训练的ResNet-50模型
resnet = models.resnet50(pretrained=False)
# 将修改后的模型赋值给自定义的ResNet-50网络
self.model = resnet
# 修改第一个卷积层的输入通道数
self.model.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
# 修改全连接层的输出特征数
self.fc = nn.Linear(2048 * self.chunk, num_classes)
def cat(self,x):
x = self.model.conv1(x)
x = self.model.bn1(x)
x = self.model.relu(x)
x = self.model.maxpool(x)
x = self.model.layer1(x)
x = self.model.layer2(x)
x = self.model.layer3(x)
x = self.model.layer4(x)
x = self.model.avgpool(x)
x = torch.flatten(x, 1)
return x
def forward(self, x ):
# 对输入的三通道图像进行分割
self.data = []
split_data = torch.split(x, 1, dim=1)
for i in range(self.chunk):
self.data.append(self.cat(split_data[i]))
# 在全连接层之前进行拼接
x = torch.cat(self.data, dim=1)
# 全连接层输出
output = self.fc(x)
return output
我们需要先实例化,in_channels为1通道的原始图像,nu_classes为2类别数。输入的图像为(B,N,H,W),chunk为合成的几维数据,这里是将三张图像融合,所以是三维。
# 创建一个实例,并指定类别数
net = CustomResNet50( in_channels=1,num_classes=2,chunk=1)
# 输出网络结构
print(net)
# 在输入数据上进行前向传播
input_data = torch.randn(64, 1, 224, 224) # 假设输入数据尺寸为1x224x224
output = net(input_data)
print("前连接层特征尺寸:", output.size())
这样就只训练一个网络就可以了,之前是训练了两个网络来融合,不好操作。
这里只是融合1通道的医学图像,如果为3通道的彩色图像,需要先将图片的维度合并,然后将输入的维度进行3个3个拆分后卷积拼接,等用到的时候再改代码吧。