第P4周:PyTorch实现猴痘病识别(特点:本地数据、二分类)

1. 数据集说明

本项目是一个二分类项目,数据集的类别为Monkeypox, Others。该数据集的组织形式为MonkeypoxOthers文件夹下分别存放若干图像。

对于这种组织形式的数据,可以采用datasets.ImageFolder()方法提取出全部数据集,提取结果是一个字典,其中数据集以元组的形式存储,即(绝对路径,类别[0 or 1]),可以通过total_datatset.class_to_idx进行查看数据类别。

因为本案例的数据集是关于猴痘病的,视觉上过于恶心,所以不再展示数据集的效果。

2. 模型训练部分

def train(dataloader, model, opt, loss_fn):
    total_data_num = len(dataloader.dataset)
    total_batch = len(dataloader)

    train_acc, train_loss = 0, 0
    for img, label in dataloader:
        img, label = img.to(device()), label.to(device())

        pred = model(img)
        loss = loss_fn(pred, label)

        opt.zero_grad()
        loss.backward()
        opt.step()

        train_acc += (pred.argmax(axis=1) == label).type(torch.float).sum().item()
        train_loss += loss.item()

    return train_acc / total_data_num, train_loss / total_batch

3. 模型测试部分

def test(dataloader, model , loss_fn):
    total_data_num = len(dataloader.dataset)
    total_batch = len(dataloader)

    test_acc, test_loss = 0, 0
    with torch.no_grad():
        for img, label in dataloader:
            img, label = img.to(device()), label.to(device())

            pred = model(img)
            loss = loss_fn(pred, label)

            test_acc += (pred.argmax(axis=1) == label).type(torch.float).sum().item()
            test_loss += loss.item()

    return test_acc / total_data_num, test_loss / total_batch

4. 主函数

import time
import torch
from PIL import Image
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms

if "__main__" == __name__:
    mymodel = MyModel().to(device())

    # Memory --> Dataset
    trans = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    total_datatset = datasets.ImageFolder("./dataset", transform=trans)
    # print(total_datatset.class_to_idx)
    train_size = int(0.8 * len(total_datatset))
    test_size = len(total_datatset) - train_size

    train_data, test_data = torch.utils.data.random_split(total_datatset, lengths=[train_size, test_size])

    # Dataset --> Dataloader
    batch_size = 32
    train_dl =DataLoader(train_data, batch_size=batch_size, shuffle=True, drop_last=True)
    test_dl = DataLoader(test_data, batch_size=batch_size, shuffle=True, drop_last=True)

    # Hyperparameters and logger
    optimizer = torch.optim.SGD(mymodel.parameters(), lr=0.001, momentum=0.9)
    loss_fn = nn.CrossEntropyLoss()
    epochs = 50
    time_stamp = time.strftime("%Y/%m/%d-%H.%M.%S")
    writer = SummaryWriter(comment=time_stamp)

    # plot train_dataloader and test_dataloader
    img_step = 0
    for img, label in train_dl:
        writer.add_images("Train dataloader", img, img_step)
        img_step += 1

    for img, label in test_dl:
        writer.add_images("Test dataloader", img, img_step)
        img_step += 1

    # model train
    for epoch in range(epochs):

        mymodel.train()
        train_acc, train_loss = train(train_dl, mymodel, optimizer, loss_fn)

        mymodel.eval()
        test_acc, test_loss = test(test_dl, mymodel, loss_fn)

        template = "Epoch:{:2d}, Train_acc:{:.2f}%, Train_loss:{:.2f}, Test_acc:{:.2f}%, Test_loss:{:.2f}"
        print(template.format(epoch, train_acc*100, train_loss, test_acc*100, test_loss))

        writer.add_scalar("Train_Loss", train_loss, epoch)
        writer.add_scalar("Test_Loss", test_loss, epoch)
        writer.add_scalar("Train_Accuracy", train_acc, epoch)
        writer.add_scalar("Test_Accuracy", test_acc, epoch)

    writer.close()

5. 运行结果

5.1. 训练集及测试集的损失函数

在这里插入图片描述
在这里插入图片描述

5.2. 训练集及测试集的精度

在这里插入图片描述
在这里插入图片描述

6. 学习心得

(本周太忙了,忙着配置openmmlab和mmcv,地块分割任务数据集的整理都还没有完成,上周预计用argparse模块和shell脚本实现超参数的管理这周没有实现实现了…下次一定!!!)

  • ImageFolder()实现本地组织好的数据集的读取
  • 3
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值