修改 torchvision.models 中 EfficientNet 的输入图像通道数（改为 4 通道）
from torchsummary import summary
import torchvision.models as models
import torch.nn as nn
def expand_conv_in_channels(conv: nn.Conv2d, extra_channels: int = 1) -> nn.Conv2d:
    """Build a copy of *conv* that accepts ``extra_channels`` more input channels.

    The weights of the original input channels are copied over unchanged, so
    existing (e.g. pretrained) filters keep working on the first
    ``conv.in_channels`` inputs. Filters for the added channels keep
    ``nn.Conv2d``'s default initialisation. Output channels, kernel size,
    stride, padding, dilation, padding mode, bias presence, device and dtype
    are carried over, and the bias (if any) is copied as well.

    Args:
        conv: the convolution to widen; must be ungrouped (``groups == 1``).
        extra_channels: number of input channels to add (must be >= 0).

    Returns:
        A new ``nn.Conv2d`` with ``conv.in_channels + extra_channels`` inputs.

    Raises:
        ValueError: if ``extra_channels`` is negative or ``conv.groups != 1``.
    """
    import torch  # local import: the file itself only imports torch.nn

    if extra_channels < 0:
        raise ValueError(f"extra_channels must be >= 0, got {extra_channels}")
    if conv.groups != 1:
        # Grouped weights have shape (out, in // groups, kH, kW); the plain
        # channel-slice copy below would not line up with them.
        raise ValueError(f"grouped convolutions are not supported (groups={conv.groups})")

    # `conv.bias` is either a Tensor or None. Truth-testing a multi-element
    # tensor raises, so compare against None explicitly.
    has_bias = conv.bias is not None
    new_conv = nn.Conv2d(
        in_channels=conv.in_channels + extra_channels,
        out_channels=conv.out_channels,
        kernel_size=conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        dilation=conv.dilation,
        bias=has_bias,
        padding_mode=conv.padding_mode,
        device=conv.weight.device,
        dtype=conv.weight.dtype,
    )
    # Copy parameters without recording the operation in autograd (replaces
    # the `.data.copy_` idiom, which bypasses autograd's safety checks).
    with torch.no_grad():
        new_conv.weight[:, : conv.in_channels] = conv.weight
        if has_bias:
            new_conv.bias.copy_(conv.bias)
    return new_conv


# Built without pretrained weights here, so the copied filters are random.
# To start from pretrained filters, construct the model with torchvision's
# pretrained-weights argument; the copy above is what preserves them.
model = models.efficientnet_b2()
old_conv1 = model.features[0][0]
# Adjust `extra_channels` to your task; here 3 + 1 = 4 input channels.
new_conv1 = expand_conv_in_channels(old_conv1, extra_channels=1)
model.features[0][0] = new_conv1
print(model)
# Derive the channel count from the new stem conv instead of hard-coding it.
summary(model, input_size=(new_conv1.in_channels, 256, 256), batch_size=-1, device='cpu')
![运行结果截图：修改为 4 通道输入后的模型结构与 summary 输出](https://img-blog.csdnimg.cn/ec187a48204e4561afb4f7ec9c9dc16c.png)