pytorch学习记录03-多分支网络的完整训练代码

本文接上一篇文章“pytorch学习记录02——多分支网络”。把完整的代码展示出来供大家借鉴。
网络的结构图和里面的参数在上一篇文章已经说过了,这里就直接放代码了。

首先,导入相关的库

# 导入库
import random
import torch
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch import nn, optim

然后,定义网络模型的结构

# 定义模型结构
class ThreeInputsNet(nn.Module):
    def __init__(self):
        super(ThreeInputsNet, self).__init__()
        # 3, 64, 64
        self.conv1_1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.pooling1_1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv1_2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.conv1_3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv1_4 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        # 3, 64, 64
        self.conv2_1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.pooling2_1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2_2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.conv2_3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv2_4 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        # 3, 64, 64
        self.conv3_1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.pooling3_1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv3_2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.conv3_3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv3_4 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        # 128, 4, 4
        # 三个通道的channel合并
        # 128*5, 4, 4
        self.outlayer1 = nn.Linear(3 * 128 * 4 * 4, 128 * 5)
        self.outlayer2 = nn.Linear(128 * 5, 256)
        self.outlayer3 = nn.Linear(256, 5)  # 是几分类第二个数就改成几,比如我们做5分类的任务,这里就是5

    # 此处的输入为三个,对应三个分支
    def forward(self, input1, input2, input3):
        out1 = self.pooling1_1(self.conv1_1(input1))
        out1 = self.pooling1_1(self.conv1_2(out1))
        out1 = self.pooling1_1(self.conv1_3(out1))
        out1 = self.pooling1_1(self.conv1_4(out1))

        out2 = self.pooling2_1(self.conv2_1(input2))
        out2 = self.pooling2_1(self.conv2_2(out2))
        out2 = self.pooling2_1(self.conv2_3(out2))
        out2 = self.pooling2_1(self.conv2_4(out2))

        out3 = self.pooling3_1(self.conv3_1(input3))
        out3 = self.pooling3_1(self.conv3_2(out3))
        out3 = self.pooling3_1(self.conv3_3(out3))
        out3 = self.pooling3_1(self.conv3_4(out3))
        # 将三个分支的结果在channel维度上合并
        out = torch.cat((out1, out2, out3), dim=1)
        out = out.view(out.size(0), -1)  # [B, C, H, W] --> [B, C*H*W]
        out = self.outlayer1(out)
        out = self.outlayer2(out)
        out = self.outlayer3(out)
        return out

然后,定义好训练过程

def main():
    device = DEVICE
    model = ThreeInputsNet().to(device)
    loss_fuc = nn.CrossEntropyLoss().to(device)  # 设置损失函数为交叉熵函数
    optimizer = optim.Adam(model.parameters(), lr=LR)  # 设置优化器, weight_decay=0.1


    # 训练模型
    top_acc = 0.0
    for epoch in range(EPOCHS):
        train_loss = 0
        train_acc = 0
        model.train()  # 声明为train模式
        # 使用for循环同时遍历三个dataloader
        for (num, (input_L, label_L)), (num1, (input_R, label_R)), (num2, (input_M, label_M)) in zip(
                enumerate(train_loader_L, start=1),
                enumerate(train_loader_R, start=1),
                enumerate(train_loader_M, start=1)):
            input_L, label_L = input_L.to(device), label_L.to(device)
            input_R, label_R = input_R.to(device), label_R.to(device)
            input_M, label_M = input_M.to(device), label_M.to(device)
            # 比较三个label值是否相等(如果三个通道的输入数据要求有对应关系的话,就在这比较一下)
            assert torch.equal(label_L, label_R) and torch.equal(label_R, label_M), "训练集标签不同"

            y_ = model(input_L, input_R, input_M)
            loss = loss_fuc(y_, label_M)  # 计算损失
            optimizer.zero_grad()  # 梯度清零
            loss.backward()  # 损失回传
            optimizer.step()
            # 记录误差
            train_loss += loss.item()
            # 计算分类的准确率
            out_t = y_.argmax(dim=1)  # 取出预测的最大值
            num_correct = (out_t == label_M).sum().item()
            acc = num_correct / input_M.shape[0]
            train_acc += acc

            # 打印训练过程
            rate = (num1 + 1) / len(train_loader_M)
            a = "*" * int(rate * 50)
            b = '.' * int((1 - rate) * 50)
            print("\rtrain loss:{:^3.0f}%[{}->{}]{:.4f}".format(int(rate * 100), a, b, loss), end="")

        print("\nEpoch:", epoch + 1, 'train_loss:', train_loss / len(train_loader_M), " train_acc:",
              train_acc / len(train_loader_M))

        # 测试模型
        model.eval()  # 声明为test模式
        with torch.no_grad():  # with这一段不需要构建计算图
            # test
            total_correct = 0  # 正确的数量
            total_num = 0
            for (num1, (x1, label1)), (num2, (x2, label2)), (num3, (x3, label3)) in zip(
                    enumerate(test_loader_L, start=1),
                    enumerate(test_loader_R, start=1),
                    enumerate(test_loader_M, start=1)):
                x1, label1 = x1.to(device), label1.to(device)
                x2, label2 = x2.to(device), label2.to(device)
                x3, label3 = x3.to(device), label3.to(device)
                # 判断label是否相等
                assert torch.equal(label1, label2) and torch.equal(label2, label3), "测试标签不同"

                y_ = model(x1, x2, x3)
                pred = y_.argmax(dim=1)  # 选出最大值的索引作为预测的分类结果
                correct = torch.eq(pred, label1).float().sum().item()  # 如果预测值和label值相等则正确数量加一
                total_correct += correct
                total_num += x1.size(0)

            acc = total_correct / total_num

            if acc >= top_acc:
                top_acc = acc
            print("Epoch:", epoch + 1, '; test_top_acc:', top_acc * 100.0, "; test_acc:", acc * 100.0)


最后,设置超参数并将数据送入网络

if __name__ == '__main__':
    # 超参数
    BATCH_SIZE = 8
    EPOCHS = 200
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # DEVICE = 'cpu'
    LR = 1e-3
    print(DEVICE)

    # 准备数据
    # 数据预处理
    train_transform = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
    ])

    test_transform = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
    ])

    # 读取数据
    train_dataset_L = datasets.ImageFolder('自己的路径', train_transform)
    train_dataset_R = datasets.ImageFolder('自己的路径', train_transform)
    train_dataset_M = datasets.ImageFolder('自己的路径', train_transform)

    test_dataset_L = datasets.ImageFolder('自己的路径', test_transform)
    test_dataset_R = datasets.ImageFolder('自己的路径', test_transform)
    test_dataset_M = datasets.ImageFolder('自己的路径', test_transform)
    print(train_dataset_L.class_to_idx)
    print(test_dataset_L.class_to_idx)

    # 导入数据
    seed = random.randint(0, 100)  # 设置随机种子,用于打乱数据
    g = torch.Generator()
    g.manual_seed(seed)  # 如果不加这句,每次启动程序后,随机的结果都是一样的
    train_loader_L = torch.utils.data.DataLoader(train_dataset_L, batch_size=BATCH_SIZE, shuffle=True, generator=g)
    g = torch.Generator()
    g.manual_seed(seed)
    train_loader_R = torch.utils.data.DataLoader(train_dataset_R, batch_size=BATCH_SIZE, shuffle=True, generator=g)
    g = torch.Generator()
    g.manual_seed(seed)
    train_loader_M = torch.utils.data.DataLoader(train_dataset_M, batch_size=BATCH_SIZE, shuffle=True, generator=g)

    # 对于训练集。测试一下是否打乱了数据,以及打乱之后三个通道的标签是否能对应上
    # for (num1, (input_L, label_L)), (num2, (input_R, label_R)), (num3, (input_M, label_M)),\
    #         in zip(enumerate(train_loader_L, start=1),
    #               enumerate(train_loader_R, start=1),
    #               enumerate(train_loader_M, start=1)):
    #
    #     print(num1)
    #     print("label_L:{}".format(label_L))
    #     print("label_R:{}".format(label_R))
    #     print("label_M: {}".format(label_M))
    #     if num1 == 5:
    #         break

    g2 = torch.Generator()
    g2.manual_seed(seed)
    test_loader_L = torch.utils.data.DataLoader(test_dataset_L, batch_size=BATCH_SIZE, shuffle=True, generator=g2)
    g2 = torch.Generator()
    g2.manual_seed(seed)
    test_loader_R = torch.utils.data.DataLoader(test_dataset_R, batch_size=BATCH_SIZE, shuffle=True, generator=g2)
    g2 = torch.Generator()
    g2.manual_seed(seed)
    test_loader_M = torch.utils.data.DataLoader(test_dataset_M, batch_size=BATCH_SIZE, shuffle=True, generator=g2)

    # 对于测试集。测试一下是否打乱了数据,以及打乱之后三个通道的标签是否能对应上
    # for (num1, (x1, label1)), (num2, (x2, label2)), (num3, (x3, label3)),\
    #         in zip(enumerate(test_loader_L, start=1),
    #                enumerate(test_loader_R, start=1),
    #                enumerate(test_loader_M, start=1),):
    #     print(num1)
    #     print("test_l1:{}".format(label1))
    #     print("test_l2:{}".format(label2))
    #     print("test_l3:{}".format(label3))
    #     if num1 == 5:
    #         break
    # 开始训练
    main()

便于大家复制,这里给出整体的代码

# 导入库
import torch
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch import nn, optim
import random


# 定义模型结构
class ThreeInputsNet(nn.Module):
    def __init__(self):
        super(ThreeInputsNet, self).__init__()
        # 3, 64, 64
        self.conv1_1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.pooling1_1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv1_2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.conv1_3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv1_4 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        # 3, 64, 64
        self.conv2_1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.pooling2_1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2_2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.conv2_3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv2_4 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        # 3, 64, 64
        self.conv3_1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.pooling3_1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv3_2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.conv3_3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv3_4 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        # 128, 4, 4
        # 三个通道的channel合并
        # 128*5, 4, 4
        self.outlayer1 = nn.Linear(3 * 128 * 4 * 4, 128 * 5)
        self.outlayer2 = nn.Linear(128 * 5, 256)
        self.outlayer3 = nn.Linear(256, 5)  # 是几分类第二个数就改成几,比如我们做5分类的任务,这里就是5

    # 此处的输入为三个,对应三个分支
    def forward(self, input1, input2, input3):
        out1 = self.pooling1_1(self.conv1_1(input1))
        out1 = self.pooling1_1(self.conv1_2(out1))
        out1 = self.pooling1_1(self.conv1_3(out1))
        out1 = self.pooling1_1(self.conv1_4(out1))

        out2 = self.pooling2_1(self.conv2_1(input2))
        out2 = self.pooling2_1(self.conv2_2(out2))
        out2 = self.pooling2_1(self.conv2_3(out2))
        out2 = self.pooling2_1(self.conv2_4(out2))

        out3 = self.pooling3_1(self.conv3_1(input3))
        out3 = self.pooling3_1(self.conv3_2(out3))
        out3 = self.pooling3_1(self.conv3_3(out3))
        out3 = self.pooling3_1(self.conv3_4(out3))
        # 将三个分支的结果在channel维度上合并
        out = torch.cat((out1, out2, out3), dim=1)
        out = out.view(out.size(0), -1)  # [B, C, H, W] --> [B, C*H*W]
        out = self.outlayer1(out)
        out = self.outlayer2(out)
        out = self.outlayer3(out)
        return out


def main():
    device = DEVICE
    model = ThreeInputsNet().to(device)
    loss_fuc = nn.CrossEntropyLoss().to(device)  # 设置损失函数为交叉熵函数
    optimizer = optim.Adam(model.parameters(), lr=LR)  # 设置优化器, weight_decay=0.1


    # 训练模型
    top_acc = 0.0
    for epoch in range(EPOCHS):
        train_loss = 0
        train_acc = 0
        model.train()  # 声明为train模式
        # 使用for循环同时遍历三个dataloader
        for (num, (input_L, label_L)), (num1, (input_R, label_R)), (num2, (input_M, label_M)) in zip(
                enumerate(train_loader_L, start=1),
                enumerate(train_loader_R, start=1),
                enumerate(train_loader_M, start=1)):
            input_L, label_L = input_L.to(device), label_L.to(device)
            input_R, label_R = input_R.to(device), label_R.to(device)
            input_M, label_M = input_M.to(device), label_M.to(device)
            # 比较三个label值是否相等(如果三个通道的输入数据要求有对应关系的话,就在这比较一下)
            assert torch.equal(label_L, label_R) and torch.equal(label_R, label_M), "训练集标签不同"

            y_ = model(input_L, input_R, input_M)
            loss = loss_fuc(y_, label_M)  # 计算损失
            optimizer.zero_grad()  # 梯度清零
            loss.backward()  # 损失回传
            optimizer.step()
            # 记录误差
            train_loss += loss.item()
            # 计算分类的准确率
            out_t = y_.argmax(dim=1)  # 取出预测的最大值
            num_correct = (out_t == label_M).sum().item()
            acc = num_correct / input_M.shape[0]
            train_acc += acc

            # 打印训练过程
            rate = (num1 + 1) / len(train_loader_M)
            a = "*" * int(rate * 50)
            b = '.' * int((1 - rate) * 50)
            print("\rtrain loss:{:^3.0f}%[{}->{}]{:.4f}".format(int(rate * 100), a, b, loss), end="")

        print("\nEpoch:", epoch + 1, 'train_loss:', train_loss / len(train_loader_M), " train_acc:",
              train_acc / len(train_loader_M))

        # 测试模型
        model.eval()  # 声明为test模式
        with torch.no_grad():  # with这一段不需要构建计算图
            # test
            total_correct = 0  # 正确的数量
            total_num = 0
            for (num1, (x1, label1)), (num2, (x2, label2)), (num3, (x3, label3)) in zip(
                    enumerate(test_loader_L, start=1),
                    enumerate(test_loader_R, start=1),
                    enumerate(test_loader_M, start=1)):
                x1, label1 = x1.to(device), label1.to(device)
                x2, label2 = x2.to(device), label2.to(device)
                x3, label3 = x3.to(device), label3.to(device)
                # 判断label是否相等
                assert torch.equal(label1, label2) and torch.equal(label2, label3), "测试标签不同"

                y_ = model(x1, x2, x3)
                pred = y_.argmax(dim=1)  # 选出最大值的索引作为预测的分类结果
                correct = torch.eq(pred, label1).float().sum().item()  # 如果预测值和label值相等则正确数量加一
                total_correct += correct
                total_num += x1.size(0)

            acc = total_correct / total_num

            if acc >= top_acc:
                top_acc = acc
            print("Epoch:", epoch + 1, '; test_top_acc:', top_acc * 100.0, "; test_acc:", acc * 100.0)


if __name__ == '__main__':
    # 超参数
    BATCH_SIZE = 8
    EPOCHS = 200
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # DEVICE = 'cpu'
    LR = 1e-3
    print(DEVICE)

    # 准备数据
    # 数据预处理
    train_transform = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
    ])

    test_transform = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
    ])

    # 读取数据
    train_dataset_L = datasets.ImageFolder('自己的路径', train_transform)
    train_dataset_R = datasets.ImageFolder('自己的路径', train_transform)
    train_dataset_M = datasets.ImageFolder('自己的路径', train_transform)

    test_dataset_L = datasets.ImageFolder('自己的路径', test_transform)
    test_dataset_R = datasets.ImageFolder('自己的路径', test_transform)
    test_dataset_M = datasets.ImageFolder('自己的路径', test_transform)
    print(train_dataset_L.class_to_idx)
    print(test_dataset_L.class_to_idx)

    # 导入数据
    seed = random.randint(0, 100)  # 设置随机种子,用于打乱数据
    g = torch.Generator()
    g.manual_seed(seed)  # 如果不加这句,每次启动程序后,随机的结果都是一样的
    train_loader_L = torch.utils.data.DataLoader(train_dataset_L, batch_size=BATCH_SIZE, shuffle=True, generator=g)
    g = torch.Generator()
    g.manual_seed(seed)
    train_loader_R = torch.utils.data.DataLoader(train_dataset_R, batch_size=BATCH_SIZE, shuffle=True, generator=g)
    g = torch.Generator()
    g.manual_seed(seed)
    train_loader_M = torch.utils.data.DataLoader(train_dataset_M, batch_size=BATCH_SIZE, shuffle=True, generator=g)

    # 对于训练集。测试一下是否打乱了数据,以及打乱之后三个通道的标签是否能对应上
    # for (num1, (input_L, label_L)), (num2, (input_R, label_R)), (num3, (input_M, label_M)),\
    #         in zip(enumerate(train_loader_L, start=1),
    #               enumerate(train_loader_R, start=1),
    #               enumerate(train_loader_M, start=1)):
    #
    #     print(num1)
    #     print("label_L:{}".format(label_L))
    #     print("label_R:{}".format(label_R))
    #     print("label_M: {}".format(label_M))
    #     if num1 == 5:
    #         break

    g2 = torch.Generator()
    g2.manual_seed(seed)
    test_loader_L = torch.utils.data.DataLoader(test_dataset_L, batch_size=BATCH_SIZE, shuffle=True, generator=g2)
    g2 = torch.Generator()
    g2.manual_seed(seed)
    test_loader_R = torch.utils.data.DataLoader(test_dataset_R, batch_size=BATCH_SIZE, shuffle=True, generator=g2)
    g2 = torch.Generator()
    g2.manual_seed(seed)
    test_loader_M = torch.utils.data.DataLoader(test_dataset_M, batch_size=BATCH_SIZE, shuffle=True, generator=g2)

    # 对于测试集。测试一下是否打乱了数据,以及打乱之后三个通道的标签是否能对应上
    # for (num1, (x1, label1)), (num2, (x2, label2)), (num3, (x3, label3)),\
    #         in zip(enumerate(test_loader_L, start=1),
    #                enumerate(test_loader_R, start=1),
    #                enumerate(test_loader_M, start=1),):
    #     print(num1)
    #     print("test_l1:{}".format(label1))
    #     print("test_l2:{}".format(label2))
    #     print("test_l3:{}".format(label3))
    #     if num1 == 5:
    #         break
    # 开始训练
    main()

  • 7
    点赞
  • 34
    收藏
    觉得还不错? 一键收藏
  • 10
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 10
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值