pytorch学习记录02——多分支网络

本文使用pytorch框架来搭建一个多分支的神经网络,编程时借鉴了Inception的编程思想。
准备实现的网络结构如下图所示

(此处为网络结构示意图:三条并行的卷积分支分别处理各自的输入,末端在通道维度上拼接后接入全连接层)

每个分支的输入图片大小设置为64*64,卷积层和池化层的参数设置如下表所示

层         参数
CONV1      输入通道 3,输出通道 16,kernel_size=3, stride=1, padding=1
Pooling1   kernel_size=2, stride=2
CONV2      输入通道 16,输出通道 32,kernel_size=3, stride=1, padding=1
Pooling2   kernel_size=2, stride=2
CONV3      输入通道 32,输出通道 64,kernel_size=3, stride=1, padding=1
Pooling3   kernel_size=2, stride=2
CONV4      输入通道 64,输出通道 128,kernel_size=3, stride=1, padding=1
Pooling4   kernel_size=2, stride=2

首先,导入需要的库

# 导入库
import torch
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch import nn, optim

接下来,定义网络模型

# 定义模型结构
class ThreeInputsNet(nn.Module):
    """Three-branch CNN, Inception-style wiring.

    Each branch independently maps a (B, 3, 64, 64) image through four
    conv(3x3, stride 1, pad 1) + maxpool(2x2) stages, halving the spatial
    size each stage: 64 -> 32 -> 16 -> 8 -> 4, ending at (B, 128, 4, 4).
    The three branch outputs are concatenated on the channel dimension
    (3 * 128 = 384 channels), flattened, and passed through three
    fully-connected layers down to 3 output units.

    NOTE(review): there is no nonlinearity between the conv stages or
    between the stacked Linear layers, so the three Linear layers collapse
    into a single affine map. Presumably intentional for this wiring demo;
    add ReLU between stages for real training — confirm with the author.
    """

    def __init__(self):
        super(ThreeInputsNet, self).__init__()
        # Branch 1: input (3, 64, 64)
        self.conv1_1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.pooling1_1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv1_2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.conv1_3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv1_4 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        # Branch 2: input (3, 64, 64)
        self.conv2_1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.pooling2_1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2_2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.conv2_3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv2_4 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        # Branch 3: input (3, 64, 64)
        self.conv3_1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.pooling3_1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv3_2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.conv3_3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv3_4 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        # Each branch ends at (128, 4, 4); after channel-wise concat of the
        # three branches the classifier input is (3*128, 4, 4) = (384, 4, 4).
        self.outlayer1 = nn.Linear(3 * 128 * 4 * 4, 128 * 3)
        self.outlayer2 = nn.Linear(128 * 3, 256)
        self.outlayer3 = nn.Linear(256, 3)

    def _run_branch(self, x, convs, pool):
        """Apply each conv followed by 2x2 max pooling: 64 -> 32 -> 16 -> 8 -> 4."""
        for conv in convs:
            x = pool(conv(x))
        return x

    def forward(self, input1, input2, input3):
        """Run the three inputs through their branches and classify.

        Args:
            input1, input2, input3: tensors of shape (B, 3, 64, 64),
                one per branch.
        Returns:
            Tensor of shape (B, 3) — raw scores for 3 classes.
        """
        out1 = self._run_branch(
            input1, (self.conv1_1, self.conv1_2, self.conv1_3, self.conv1_4),
            self.pooling1_1)
        out2 = self._run_branch(
            input2, (self.conv2_1, self.conv2_2, self.conv2_3, self.conv2_4),
            self.pooling2_1)
        out3 = self._run_branch(
            input3, (self.conv3_1, self.conv3_2, self.conv3_3, self.conv3_4),
            self.pooling3_1)
        # Merge the three branch results on the channel dimension.
        out = torch.cat((out1, out2, out3), dim=1)  # (B, 384, 4, 4)
        out = out.view(out.size(0), -1)  # (B, C, H, W) --> (B, C*H*W)
        out = self.outlayer1(out)
        out = self.outlayer2(out)
        out = self.outlayer3(out)
        return out

输入一些数据测试一下网络能否跑通

if __name__ == '__main__':
    # Smoke test: push three dummy batches of 3x64x64 images through the net
    # and confirm the output shape.
    dummy_batches = [torch.ones(8, 3, 64, 64) for _ in range(3)]
    net = ThreeInputsNet()
    output = net(*dummy_batches)
    print("out.shape:{}".format(output.shape))

完整代码


# 导入库
import torch
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch import nn, optim


# 定义模型结构
# Model definition
class ThreeInputsNet(nn.Module):
    """A three-branch convolutional network.

    Every branch runs its own (B, 3, 64, 64) input through four
    conv(3x3, stride 1, pad 1) + maxpool(2x2) stages, halving the spatial
    size at each stage until it reaches (B, 128, 4, 4). The three feature
    maps are then concatenated along the channel axis, flattened, and fed
    through three fully-connected layers producing 3 outputs.
    """

    def __init__(self):
        super(ThreeInputsNet, self).__init__()
        # Branch 1: (3, 64, 64) in
        self.conv1_1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.pooling1_1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv1_2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.conv1_3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv1_4 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        # Branch 2: (3, 64, 64) in
        self.conv2_1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.pooling2_1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2_2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.conv2_3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv2_4 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        # Branch 3: (3, 64, 64) in
        self.conv3_1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.pooling3_1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv3_2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.conv3_3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv3_4 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        # Each branch ends at (128, 4, 4); channel-wise concat of the three
        # branches gives 3 * 128 = 384 channels, hence 3 * 128 * 4 * 4 inputs.
        self.outlayer1 = nn.Linear(3 * 128 * 4 * 4, 128 * 5)
        self.outlayer2 = nn.Linear(128 * 5, 256)
        self.outlayer3 = nn.Linear(256, 3)

    def forward(self, input1, input2, input3):
        """Three inputs, one per branch; returns (B, 3) class scores."""
        branch_specs = (
            (input1, self.pooling1_1,
             (self.conv1_1, self.conv1_2, self.conv1_3, self.conv1_4)),
            (input2, self.pooling2_1,
             (self.conv2_1, self.conv2_2, self.conv2_3, self.conv2_4)),
            (input3, self.pooling3_1,
             (self.conv3_1, self.conv3_2, self.conv3_3, self.conv3_4)),
        )
        features = []
        for x, pool, convs in branch_specs:
            for conv in convs:
                x = pool(conv(x))  # spatial size halves every stage
            features.append(x)
        # Merge the three branch results on the channel dimension.
        merged = torch.cat(features, dim=1)        # (B, 384, 4, 4)
        flat = merged.view(merged.size(0), -1)     # (B, C, H, W) --> (B, C*H*W)
        return self.outlayer3(self.outlayer2(self.outlayer1(flat)))


if __name__ == '__main__':
    # Quick sanity check that the network runs end to end.
    net = ThreeInputsNet()
    x1 = torch.ones(8, 3, 64, 64)
    x2 = torch.ones(8, 3, 64, 64)
    x3 = torch.ones(8, 3, 64, 64)
    output = net(x1, x2, x3)
    print("out.shape:{}".format(output.shape))
  • 4
    点赞
  • 52
    收藏
    觉得还不错? 一键收藏
  • 10
    评论
Pytorch是机器学习中的一个重要框架,它与TensorFlow一起被认为是机器学习的两大框架。Pytorch学习可以从以下几个方面入手: 1. Pytorch基本语法:了解Pytorch的基本语法和操作,包括张量(Tensors)的创建、导入torch库、基本运算等\[2\]。 2. Pytorch中的autograd:了解autograd的概念和使用方法,它是Pytorch中用于自动计算梯度的工具,可以方便地进行反向传播\[2\]。 3. 使用Pytorch构建一个神经网络学习使用torch.nn库构建神经网络的典型流程,包括定义网络结构、损失函数、反向传播和更新网络参数等\[2\]。 4. 使用Pytorch构建一个分类器:了解如何使用Pytorch构建一个分类器,包括任务和数据介绍、训练分类器的步骤以及在GPU上进行训练等\[2\]。 5. Pytorch的安装:可以通过pip命令安装Pytorch,具体命令为"pip install torch torchvision torchaudio",这样就可以在Python环境中使用Pytorch了\[3\]。 以上是一些关于Pytorch学习的笔记,希望对你有帮助。如果你需要更详细的学习资料,可以参考引用\[1\]中提到的网上帖子,或者查阅Pytorch官方文档。 #### 引用[.reference_title] - *1* [pytorch自学笔记](https://blog.csdn.net/qq_41597915/article/details/123415393)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^control_2,239^v3^insert_chatgpt"}} ] [.reference_item] - *2* *3* [Pytorch学习笔记](https://blog.csdn.net/pizm123/article/details/126748381)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^control_2,239^v3^insert_chatgpt"}} ] [.reference_item] [ .reference_list ]

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 10
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值