Pytorch实战宝可梦分类-自定义数据集完成宝可梦分类案例分步解析

Pytorch实战-自定义数据集完成宝可梦分类案例分步解析

前言、准备工作

本案例需要导入的包, 没有下载的通过pip install来下载

部分库的详细安装教程可以看我之前的文章
Visdom的下载与踩坑
pytorch的安装 基于anaconda

import torch
import os
import glob
import random, csv, time
from torch.utils.data import Dataset, DataLoader
from torch import nn, optim
import torchvision
from torchvision import transforms
from torchvision.models import resnet18
import visdom
from PIL import Image

一、数据集介绍

自定义的数据集内容如下

  • 皮卡丘:234
  • 超梦:239
  • 杰尼龟:223
  • 小火龙:238
  • 妙蛙种子:234

已经将照片存储至相应的文件夹下, 如下
在这里插入图片描述

二、自定义数据集分步解读

Dataset基础文章: Pytorch 快速详解如何构建自己的Dataset完成数据预处理(附详细过程)

自定义的Dataset大致框架如下, 这方面不太懂的可以看看我之前的文章.

class Pokemon(Dataset):
    def __init__(self):
        pass

    def __len__(self):
        pass

    def __getitem__(self, item):
        pass

1.观察数据集

观察一下数据集中的图片,发现图片的类型有jpg,png,jpeg, 并且图片的大小各不相同,因此我们需要对训练的图片做resize等操作
在这里插入图片描述
在这里插入图片描述

2.类别映射关系构建

    def __init__(self, root, resize, mode):
        super(Pokemon, self).__init__()
        self.root = root
        self.resize = resize
        self.mode = mode
        pass
  • root: 数据集所在的根目录
  • resize 数据集中提供的数据统一大小
  • mdoe 读取数据集时的模式 train,val,test

因为在模型中label需要转换为相应的int形, 我希望初始化函数能自动的给出的root路径里的文件夹中读取出namelabel的映射关系,这更符合应用中的实际情况
简单的实现映射效果如下代码即可

  ...: dic = {}
  ...: con = 0
  ...: for name in os.listdir(root):
  ...:     dic[name] = con 
  ...:     con += 1
  ...: print(dic)
{'bulbasaur': 0, 'charmander': 1, 'mewtwo': 2, 'pikachu': 3, 'squirtle': 4}

考虑到实际场景中可能出现的状况,使用下面的代码来构建映射

 self.name2label = {} # "sq...":0
 for name in sorted(os.listdir(os.path.join(root))):
     if not os.path.isdir(os.path.join(root, name)):
         continue

     self.name2label[name] = len(self.name2label.keys())

3. 初始化图片读取

因为这里我们希望得到的是一个包含imgpath,label的对象,所以在第一次运行的时候我们可以自定义函数将这样的关系存储至一个csv文件中

glob 文件名模式匹配: 通过指定的筛选规则返回指定路径下的所有满足规则的文件 并且可以进行迭代

存储imageslabel的关系

    def load_csv(self, filename):
        # filename指的是csv的名字,这里将映射的csv文件存储在root目录下,如果存在则跳过
        if not os.path.exists(os.path.join(self.root, filename)):
            images = []
            for name in self.name2label.keys():
                images += glob.glob(os.path.join(self.root, name, '*.png'))
                images += glob.glob(os.path.join(self.root, name, '*.jpg'))
                images += glob.glob(os.path.join(self.root, name, '*.jpeg'))
            # 打乱顺序
            random.shuffle(images)
            with open(os.path.join(self.root, filename), mode='w', newline='') as f:
                writer = csv.writer(f)
                for img in images:
                    # images中已经包含了label,这里通过split来读取出来
                    # imges: 如'pokemon\\bulbasaur\\00000000.png'
                    name = img.split(os.sep)[-2]
                    label = self.name2label[name]
                    # 按 'pokemon\\bulbasaur\\00000000.png', 0 写入csv中
                    writer.writerow([img, label])
                print('writen into csv file successful:', filename)

images.csv的结果如下图所示
在这里插入图片描述
从csv中读取映射关系,便于加载数据集,最后返回imageslabels,在init函数中存储为类变量

        # read from csv file
        images, labels = [], []
        with open(os.path.join(self.root, filename)) as f:
            reader = csv.reader(f)
            for row in reader:
                # 'pokemon\\bulbasaur\\00000000.png', 0
                img, label = row
                images.append(img)
                labels.append(int(label))

        assert len(images) == len(labels)
        return images, labels

4.划分数据集

这里按照60,20,20的比例来分割数据集为train,test,val

        if mode=='train': # 60%
            self.images = self.images[:int(0.6*len(self.images))]
            self.labels = self.labels[:int(0.6*len(self.labels))]
        elif mode=='val': # 20% = 60%->80%
            self.images = self.images[int(0.6*len(self.images)):int(0.8*len(self.images))]
            self.labels = self.labels[int(0.6*len(self.labels)):int(0.8*len(self.labels))]
        else: # 20% = 80%->100%
            self.images = self.images[int(0.8*len(self.images)):]
            self.labels = self.labels[int(0.8*len(self.labels)):]

5.getitem

对获取到的imglabel做相应的变换处理,并转换为tensor对象

    def __getitem__(self, idx):
        img, label = self.images[idx], self.labels[idx]
        tf = transforms.Compose([
            lambda x: Image.open(x).convert('RGB'),  # string path= > image data
            transforms.Resize((int(self.resize * 1.25), int(self.resize * 1.25))),
            transforms.RandomRotation(15),
            transforms.CenterCrop(self.resize),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

        img = tf(img)
        label = torch.tensor(label)
        return img, label

6. 查看自定义数据集的读取效果

这里借助visdom库来可视化查看图片, 因为对于图像数据它可以直接从tensor对象来转化
Visdom的下载使用方法请看本链接
首先在命令行启动 python -m visdom.server
创建test函数来观察自定义数据集的效果, 这里将初始batch设置为32张图片
使用Dataloader来创建一个loader, 其好处是可以指定batch并且可以shuffle数据

def test():
    viz = visdom.Visdom()
    start = time.time()
    db = Pokemon(r'D:\Source\Datasets\pokeman', 64, 'train')
    loader = DataLoader(db, batch_size=32, shuffle=True)
    for x, y in loader:
        viz.images(x, nrow=8, win='batch', opts=dict(title='batch'))
        viz.text(str(y.numpy()), win='label', opts=dict(title='batch-y'))

        time.sleep(10)
    print('time:', time.time() - start)

在这里插入图片描述
这里看到的效果很诡异, 这是因为我们在__getitem__函数中添加了transform的操作

注: 其中的meanstd参数来自ImageNet的均值和标准差。使用Imagenet的均值和标准差是一种常见的做法。它们是根据数百万张图像计算得出的

在这里插入图片描述
这里影响视觉效果的主要是Normalize操作,因此我们可以写一个函数来起到denormalize的效果

    def denormalize(self, x_hat):

        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]

        # x_hat = (x-mean)/std
        # x = x_hat*std + mean
        # x: [c, h, w]
        # mean: [3] => [3, 1, 1]
        mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1)
        std = torch.tensor(std).unsqueeze(1).unsqueeze(1)
        x = x_hat * std + mean

        return x

并且将可视化的对象修改为viz.images(db.denormalize(x), nrow=8, win='batch', opts=dict(title='batch'))
可以看到正常的图片
在这里插入图片描述

二、自定义数据集快速构建法

如果数据集的存放结构比较整齐,类似下图这样

在这里插入图片描述
就可以用ImageFolder一行代码来代替所有的步骤, 仅仅需要事先指定一下transform的内容, 这里就简单的做个resize
ImageFolder假设所有的文件按文件夹保存,每个文件夹下存储同一个类别的图片,文件夹名为类名

    tf = transforms.Compose([
                    transforms.Resize((64,64)),
                    transforms.ToTensor(),
    ])
    db = torchvision.datasets.ImageFolder(root=r'D:\Source\Datasets\pokeman', transform=tf)

结果十分顺利
在这里插入图片描述
并且ImageFolder类已经内置好了方法构建出了类别与文件夹名的映射关系, 查看方式如下

print(db.class_to_idx)

{'bulbasaur': 0, 'charmander': 1, 'mewtwo': 2, 'pikachu': 3, 'squirtle': 4}

三、迁移学习训练分类器(自定义数据类实现)

这里使用resnet18来训练,导入方式如下

from    torchvision.models import resnet18

1.导入并修改resnet18

    trained_model = resnet18(pretrained=True)

这里需要设置参数pretrained=True, 获取已经预训练好的参数
对于resnet的最后一层我们需要手动的做一些修改使其能够适合我们自定义的数据集
因为torch没有提供Flatten层, 这里我们可以手动写一个Flatten类完成拉平的操作, 其核心就是用view函数来修改维度

class Flatten(nn.Module):

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        shape = torch.prod(torch.tensor(x.shape[1:])).item()
        return x.view(-1, shape)

如果不熟悉resnet, 我们可以输出一下前面17层的输出, 来决定如何修改网络
trained_model.children())[:-1]来获取网络的前17层

trained_model = resnet18(pretrained=True)
model = nn.Sequential(*list(trained_model.children())[:-1],  # [b, 512, 1, 1]
                          Flatten(),  # [b, 512, 1, 1] => [b, 512]
                          ).to(device)

先输出一下网络结构是怎么样的

Sequential(
  (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
  (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (5): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (6): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (7): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (8): AdaptiveAvgPool2d(output_size=(1, 1))
  (9): Flatten()
)
time: 3.332292318344116

Process finished with exit code 0

这里可以随便创建一个适合的维度来看看输出什么

x = torch.randn(2, 3, 64, 64).to(device)
print(model(x).shape)
torch.Size([2, 512])

所以添加一个线性层即可

    trained_model = resnet18(pretrained=True)
    model = nn.Sequential(*list(trained_model.children())[:-1],  # [b, 512, 1, 1]
                          Flatten(),  # [b, 512, 1, 1] => [b, 512]
                          nn.Linear(512, 5)
                          ).to(device)

2.编写函数计算准确率

本函数不难, 使用total计算当前loader的总长度, 通过torch.eq(pred, y).sum()得到预测中正确的数量, 最终返回准确率

argmax等基本函数不懂的可以看我之前的文章 pytorch常用函数与基本特性总结大全

def evalute(model, loader):
    model.eval()
    
    correct = 0
    total = len(loader.dataset)

    for x,y in loader:
        x,y = x.to(device), y.to(device)
        with torch.no_grad():
            logits = model(x)
            pred = logits.argmax(dim=1)
        correct += torch.eq(pred, y).sum().float().item()

    return correct / total

3.编写训练函数

这里的函数比较通用

    '''train'''
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criteon = nn.CrossEntropyLoss()
    best_acc, best_epoch = 0, 0
    global_step = 0

    viz.line([0], [-1], win='loss', opts=dict(title='loss'))
    viz.line([0], [-1], win='val_acc', opts=dict(title='val_acc'))

    for epoch in range(epochs):

        for step, (x, y) in enumerate(train_loader):
            # x: [b, 3, 224, 224], y: [b]
            x, y = x.to(device), y.to(device)

            model.train()
            logits = model(x)
            loss = criteon(logits, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            viz.line([loss.item()], [global_step], win='loss', update='append')
            # print(f'global_step: {global_step}, loss: {loss.item()}')

            global_step += 1

        if epoch % 1 == 0:

            val_acc = evalute(model, val_loader)
            if val_acc > best_acc:
                best_epoch = epoch
                best_acc = val_acc

                torch.save(model.state_dict(), 'best.mdl')

                viz.line([val_acc], [global_step], win='val_acc', update='append')
            print(f'epoch: {epoch}, val_acc: {val_acc}')
    print('best acc:', best_acc, 'best epoch:', best_epoch)

    model.load_state_dict(torch.load('best.mdl'))
    print('loaded from ckpt!')

    test_acc = evalute(model, test_loader)
    print('test acc:', test_acc)

在第3轮的时候就在验证集上达到了93%的准确率

epoch: 0, val_acc: 0.4334763948497854
epoch: 1, val_acc: 0.7167381974248928
epoch: 2, val_acc: 0.8969957081545065
epoch: 3, val_acc: 0.9356223175965666
epoch: 4, val_acc: 0.9055793991416309
epoch: 5, val_acc: 0.7939914163090128
epoch: 6, val_acc: 0.8540772532188842
epoch: 7, val_acc: 0.9055793991416309
epoch: 8, val_acc: 0.9184549356223176
epoch: 9, val_acc: 0.944206008583691
best acc: 0.944206008583691 best epoch: 9
loaded from ckpt!
test acc: 0.9273504273504274
time: 458.2812337875366

Process finished with exit code 0

loss和acc的变换过程
在这里插入图片描述

三、迁移学习训练分类器(ImageFolder数据类实现)

刚刚的实现方式是使用自定义的Pokemon类来实现的, 因为这次数据集的存储方式十分整齐,所以也可以用ImageFolder来实现, 仅仅需要手动划分一下即可
使用random_split函数来划分db, 事先计算一下划分的大小即可

    resize = 224
    tf = transforms.Compose([
        transforms.Resize((int(resize * 1.25), int(resize * 1.25))),
        transforms.RandomRotation(15),
        transforms.CenterCrop(resize),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    db = torchvision.datasets.ImageFolder(root=r'D:\Source\Datasets\pokeman', transform=tf)
    train_size, val_size = int(len(db) * 0.6), int(len(db) * 0.2)
    test_size = len(db) - train_size - val_size
    train_db, val_db, test_db = torch.utils.data.random_split(dataset=db,
                                                              lengths=[train_size, val_size, test_size])
    train_loader = DataLoader(train_db, batch_size=batchsz, shuffle=True,
                              num_workers=0)
    val_loader = DataLoader(val_db, batch_size=batchsz, num_workers=0)
    test_loader = DataLoader(test_db, batch_size=batchsz, num_workers=0)

四、测试单张图片

这里需要事先存储一下之前的labelint的映射关系
img_label = {'bulbasaur': 0, 'charmander': 1, 'mewtwo': 2, 'pikachu': 3, 'squirtle': 4}
并且在网络测试的时候要记得使用model.eval()

# -*- coding: utf-8 -*-
# @Time    : 2022/2/3 14:24
# @Author  : JokerTong
# @File    : test42_宝可梦测试.py
import torch
from torch import nn
from torchvision.models import resnet18
from torchvision import transforms
from PIL import Image
import visdom


class Flatten(nn.Module):

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        shape = torch.prod(torch.tensor(x.shape[1:])).item()
        return x.view(-1, shape)


img_label = {'bulbasaur': 0, 'charmander': 1, 'mewtwo': 2, 'pikachu': 3, 'squirtle': 4}
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
trained_model = resnet18(pretrained=True)
model = nn.Sequential(*list(trained_model.children())[:-1],
                      Flatten(),
                      nn.Linear(512, 5)
                      ).to(device)
model.load_state_dict(torch.load('best.mdl'))
resize = 224
tf = transforms.Compose([
    transforms.Resize((int(resize * 1.25), int(resize * 1.25))),
    transforms.RandomRotation(15),
    transforms.CenterCrop(resize),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])
img = Image.open('test_pikaqiu.jpg')
img_tensor = tf(img)
img_tensor.unsqueeze_(0)
img_tensor = img_tensor.to(device)
model.eval()
out = model(img_tensor)
predict = list(img_label.keys())[torch.argmax(out).item()]
viz = visdom.Visdom()
viz.images(transforms.ToTensor()(img), win='image', opts=dict(title='image'))
viz.text('预测结果:' + predict, win='predict', opts=dict(title='predict'))
print(predict)

Setting up a new session...
pikachu

Process finished with exit code 0

可视化结果如下
在这里插入图片描述

全代码

# -*- coding: utf-8 -*-
# @Time    : 2022/2/1 11:36
# @Author  : JokerTong
# @File    : test41_自定义数据集.py
import torch
import os
import glob
import random, csv, time
from torch.utils.data import Dataset, DataLoader
from torch import nn, optim
import torchvision
from torchvision import transforms
from torchvision.models import resnet18
import visdom
from PIL import Image


class Pokemon(Dataset):
    def __init__(self, root, resize, mode):
        super(Pokemon, self).__init__()
        self.root = root
        self.resize = resize
        self.mode = mode
        self.name2label = {}  # "sq...":0
        for name in sorted(os.listdir(os.path.join(root))):
            if not os.path.isdir(os.path.join(root, name)):
                continue

            self.name2label[name] = len(self.name2label.keys())
        # print('name2label create success!', self.name2label)
        # image, label
        self.images, self.labels = self.load_csv('images.csv')
        # print('load images.csv success!', len(self.images))
        print(len(self.images))
        if mode == 'train':  # 60%
            self.images = self.images[:int(0.6 * len(self.images))]
            self.labels = self.labels[:int(0.6 * len(self.labels))]
        elif mode == 'val':  # 20% = 60%->80%
            self.images = self.images[int(0.6 * len(self.images)):int(0.8 * len(self.images))]
            self.labels = self.labels[int(0.6 * len(self.labels)):int(0.8 * len(self.labels))]
        else:  # 20% = 80%->100%
            self.images = self.images[int(0.8 * len(self.images)):]
            self.labels = self.labels[int(0.8 * len(self.labels)):]

    def load_csv(self, filename):
        # filename指的是csv的名字,这里将映射的csv文件存储在root目录下
        if not os.path.exists(os.path.join(self.root, filename)):
            images = []
            for name in self.name2label.keys():
                images += glob.glob(os.path.join(self.root, name, '*.png'))
                images += glob.glob(os.path.join(self.root, name, '*.jpg'))
                images += glob.glob(os.path.join(self.root, name, '*.jpeg'))
            # 打乱顺序
            random.shuffle(images)
            with open(os.path.join(self.root, filename), mode='w', newline='') as f:
                writer = csv.writer(f)
                for img in images:
                    # images中已经包含了label,这里通过split来读取出来
                    # imges: 如'pokemon\\bulbasaur\\00000000.png'
                    name = img.split(os.sep)[-2]
                    label = self.name2label[name]
                    # 按 'pokemon\\bulbasaur\\00000000.png', 0 写入csv中
                    writer.writerow([img, label])
                print('writen into csv file successful:', filename)

        # read from csv file
        images, labels = [], []
        with open(os.path.join(self.root, filename)) as f:
            reader = csv.reader(f)
            for row in reader:
                # 'pokemon\\bulbasaur\\00000000.png', 0
                img, label = row
                images.append(img)
                labels.append(int(label))

        assert len(images) == len(labels)
        return images, labels

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img, label = self.images[idx], self.labels[idx]
        tf = transforms.Compose([
            lambda x: Image.open(x).convert('RGB'),  # string path= > image data
            transforms.Resize((int(self.resize * 1.25), int(self.resize * 1.25))),
            transforms.RandomRotation(15),
            transforms.CenterCrop(self.resize),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

        img = tf(img)
        label = torch.tensor(label)
        return img, label

    def denormalize(self, x_hat):

        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]

        # x_hat = (x-mean)/std
        # x = x_hat*std = mean
        # x: [c, h, w]
        # mean: [3] => [3, 1, 1]
        mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1)
        std = torch.tensor(std).unsqueeze(1).unsqueeze(1)
        # print(mean.shape, std.shape)
        x = x_hat * std + mean

        return x


class Flatten(nn.Module):

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        shape = torch.prod(torch.tensor(x.shape[1:])).item()
        return x.view(-1, shape)


def test():
    # db = Pokemon(r'D:\Source\Datasets\pokeman', 128, 'train')
    tf = transforms.Compose([
        transforms.Resize((128, 120)),
        transforms.ToTensor(),
    ])
    db = torchvision.datasets.ImageFolder(root=r'D:\Source\Datasets\pokeman', transform=tf)
    print(db.class_to_idx)
    loader = DataLoader(db, batch_size=32, shuffle=True)
    for x, y in loader:
        viz.images(x, nrow=8, win='batch', opts=dict(title='batch'))
        viz.text(str(y.numpy()), win='label', opts=dict(title='batch-y'))

        # time.sleep(10)


def evalute(model, loader):
    model.eval()

    correct = 0
    total = len(loader.dataset)

    for x, y in loader:
        x, y = x.to(device), y.to(device)
        with torch.no_grad():
            logits = model(x)
            pred = logits.argmax(dim=1)
        correct += torch.eq(pred, y).sum().float().item()

    return correct / total


if __name__ == '__main__':
    '''init'''
    start = time.time()
    viz = visdom.Visdom()
    # test()
    batchsz = 32
    lr = 1e-3
    epochs = 10
    device = torch.device('cuda')
    torch.manual_seed(1234)
    '''自定义数据集'''
    # # 获取数据集
    # train_db = Pokemon(r'D:\Source\Datasets\pokeman', 224, mode='train')
    # val_db = Pokemon(r'D:\Source\Datasets\pokeman', 224, mode='val')
    # test_db = Pokemon(r'D:\Source\Datasets\pokeman', 224, mode='test')
    # # 创建loader对象
    # train_loader = DataLoader(train_db, batch_size=batchsz, shuffle=True,
    #                           num_workers=0)
    # val_loader = DataLoader(val_db, batch_size=batchsz, num_workers=0)
    # test_loader = DataLoader(test_db, batch_size=batchsz, num_workers=0)
    '''ImageFolder数据集'''
    resize = 224
    tf = transforms.Compose([
        transforms.Resize((int(resize * 1.25), int(resize * 1.25))),
        transforms.RandomRotation(15),
        transforms.CenterCrop(resize),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    db = torchvision.datasets.ImageFolder(root=r'D:\Source\Datasets\pokeman', transform=tf)
    train_size, val_size = int(len(db) * 0.6), int(len(db) * 0.2)
    test_size = len(db) - train_size - val_size
    train_db, val_db, test_db = torch.utils.data.random_split(dataset=db,
                                                              lengths=[train_size, val_size, test_size])
    train_loader = DataLoader(train_db, batch_size=batchsz, shuffle=True,
                              num_workers=0)
    val_loader = DataLoader(val_db, batch_size=batchsz, num_workers=0)
    test_loader = DataLoader(test_db, batch_size=batchsz, num_workers=0)

    # 创建resnet18
    trained_model = resnet18(pretrained=True)
    model = nn.Sequential(*list(trained_model.children())[:-1],  # [b, 512, 1, 1]
                          Flatten(),  # [b, 512, 1, 1] => [b, 512]
                          nn.Linear(512, 5)
                          ).to(device)
    # x = torch.randn(2, 3, 224, 224).to(device)
    # x = torch.randn(2, 3, 64, 64).to(device)
    # print(model)
    # print(model(x).shape)
    '''train'''
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criteon = nn.CrossEntropyLoss()
    best_acc, best_epoch = 0, 0
    global_step = 0

    viz.line([0], [-1], win='loss', opts=dict(title='loss'))
    viz.line([0], [-1], win='val_acc', opts=dict(title='val_acc'))

    for epoch in range(epochs):

        for step, (x, y) in enumerate(train_loader):
            # x: [b, 3, 224, 224], y: [b]
            x, y = x.to(device), y.to(device)

            model.train()
            logits = model(x)
            loss = criteon(logits, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            viz.line([loss.item()], [global_step], win='loss', update='append')
            # print(f'global_step: {global_step}, loss: {loss.item()}')

            global_step += 1

        if epoch % 1 == 0:

            val_acc = evalute(model, val_loader)
            if val_acc > best_acc:
                best_epoch = epoch
                best_acc = val_acc

                torch.save(model.state_dict(), 'best.mdl')

                viz.line([val_acc], [global_step], win='val_acc', update='append')
            print(f'epoch: {epoch}, val_acc: {val_acc}')
    print('best acc:', best_acc, 'best epoch:', best_epoch)

    model.load_state_dict(torch.load('best.mdl'))
    print('loaded from ckpt!')

    test_acc = evalute(model, test_loader)
    print('test acc:', test_acc)
    print('time:', time.time() - start)

  • 8
    点赞
  • 46
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Joker-Tong

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值