Pokemon数据集实战训练

文章描述了一个使用Python处理Pokemon数据集,通过Resnet18模型进行图像分类的过程。数据集包含约1100张图片,分为训练、验证和测试三部分。代码详细展示了数据加载、预处理、模型构建和训练的步骤,以及模型在训练和验证阶段的性能评估。经过训练,模型达到约90%的正确率。
摘要由CSDN通过智能技术生成

Pokemon数据集包括5类,共计约1100+张图片,每一类图片230+左右,我主要通过两个python文件完成数据集的处理和模型的训练过程。

1.第一个文件:load_dataset.py主要负责加载数据集,并将数据分为train、val、test三类数据,分别是699,233,233条数据。

2.第二个文件:resnet.py主要对Resnet18稍加修改,然后进行模型的训练和输出。

首先第一个文件(load_dataset.py)数据的处理流程为:
1.定义数据类(Pokemon),同时需要重写__len__和__getitem__方法
2.建立数据标识矩阵name2label,提取数据的类别
3.写load_csv将数据的地址和类别存入csv文件,方便以后读取
  (1)判断文件是否存在,如不存在写入数据,存在则读取数据即可
  (2)数据不存在:将数据地址和类别写如csv文件
  (3)数据存在:读取数据并将数据地址和类别分别存入images和labels数组
4.重写__len__返回数据大小
5.重写__getitem__方法返回数据类别,通过遍历iamges中的数据路径,采用PIL.Image.open(x).convert(mode='RGB')包方法读取图片并转化为RGB格式,并采用transforms.Compose做数据增强
6.将数据集划分为train、val和test数据

 具体代码与注释见代码段:

import torch
import os, glob
import random, csv
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import transforms
from PIL import Image
from visdom import Visdom
import numpy as np
import time

class Pokemon(Dataset):
    def __init__(self, root, resize, mode):
        super(Pokemon, self).__init__()

        self.root = root
        self.resize = resize
        self.mode = mode

        # 读取数据并标记数据类型
        self.name2label = {}
        # 需要将文件排序的原因是保证每次运行文件顺序一致,从而避免每次数据和标识混乱
        for name in sorted(os.listdir(os.path.join(self.root))):
            # os.path.isdir判读是否是文件夹,如果不是文件夹,则直接跳过
            if not os.path.isdir(os.path.join(self.root, name)):
                continue
            # 通过获取name2label的数据大小,对每种数据采用0-4进行标识
            self.name2label[name] = len(self.name2label.keys())
        print(self.name2label)

        self.images, self.labels = self.load_csv('images.csv')

        # 将数据划分为train\val\test数据集,分别占比0.6、0.2、0.2
        if self.mode == 'train':
            self.images = self.images[:int(0.6 * len(self.images))]
            self.labels = self.labels[:int(0.6 * len(self.labels))]
        elif self.mode == 'val':
            self.images = self.images[int(0.6 * len(self.images)):int(0.8 * len(self.images))]
            self.labels = self.labels[int(0.6 * len(self.labels)):int(0.8 * len(self.labels))]
        else:
            self.images = self.images[int(0.8 * len(self.images)):]
            self.labels = self.labels[int(0.8 * len(self.labels)):]

    # image_path+label
    def load_csv(self, filename):
        # 假如路径中没存在数据,则将数据写入文件,用于第一次遍历数据的过程
        if not os.path.exists(os.path.join(self.root, filename)):
            # images用于保存图片的路径!
            images = []
            # 保存图片格式为:'data/bulbasaur/00000224.png'
            for name in self.name2label.keys():
                # glob方法可以匹配路径中的符合的数据,然后将其加入images数组中
                # 比如下列三行匹配了以png、jpg、jpeg结尾的图片
                images += glob.glob(os.path.join(self.root, name, '*.png'))
                images += glob.glob(os.path.join(self.root, name, '*.jpg'))
                images += glob.glob(os.path.join(self.root, name, '*.jpeg'))

            print(len(images), images)
            # 打乱图片
            random.shuffle(images)
            # 将图片和对应label信息写入文件
            with open(os.path.join(self.root, filename), mode='w', newline='') as f:
                writer = csv.writer(f)
                for img in images:
                    name = img.split(os.sep)[-2]
                    # 获取图片标识的思路:由于文件的路径是data/bulbasaur/00000224.png
                    # 因此可以通过中间的文件夹姓名与标识匹配,从而确定图片种类
                    label = self.name2label[name]
                    # 保存图片格式为:[data/bulbasaur/00000224.png, 0]
                    writer.writerow([img, label])
                print('write successful to {}'.format(filename))

        # 假设文件已经存在,那就读取文件即可
        images, labels = [], []
        with open(os.path.join(self.root, filename), mode='r') as f:
            reader = csv.reader(f)
            # row标识文件内的每一行
            for row in reader:
                img, label = row
                label = int(label)

                images.append(img)
                labels.append(label)

        # 保证images和labels的长度相同
        assert len(images) == len(labels)
        return images, labels

    def denormalize(self, x_hat):
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        # x_hat = (x-mean)/std 因此反向计算还原x的数据
        # 但是x_hat的形状是[3,h,w],因此需要将mean和std的形状匹配
        mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1)
        std = torch.tensor(std).unsqueeze(1).unsqueeze(1)
        x = x_hat * std + mean
        return x

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img, label = self.images[idx], self.labels[idx]
        tf = transforms.Compose([
            lambda x: Image.open(x).convert(mode='RGB'),
            transforms.Resize(int(self.resize * 1.25)),
            transforms.RandomRotation(15),
            transforms.CenterCrop(self.resize),  # 中心裁剪,相当于将旋转后的背景填充白色
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

        img = tf(img)
        label = torch.tensor(label)

        return img, label

第二个文件(resnet.py)通过修改原始的Resnet18完成模型的搭建和分析任务,具体流程为:
1.构建Resnet中的循环模块,shortcut模块,便于网络中使用
2.构建Resnet网络
3.构造训练函数,并在训练函数中调用测试函数,并将正确率最高的模型保存在本地
4.读取正确率最高的模型,采用测试数据集计算正确率

 具体代码与注释见代码段:

import torch
from torch import nn
from torch.nn import functional as F
from load_dataset import Pokemon
import torchvision
from torch.utils.data import Dataset, DataLoader
from visdom import Visdom


# 构建Resnet中的循环模块,shortcut模块
class Resblk(nn.Module):
    def __init__(self, ch_in, ch_out, stride=1):
        super(Resblk, self).__init__()
        self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(ch_out)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(ch_out)

        self.extra = nn.Sequential()
        if ch_out != ch_in:
            self.extra = nn.Sequential(
                nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=stride),
                nn.BatchNorm2d(ch_out)
            )

    def forward(self, x):
        out = self.bn2(self.conv2(self.relu(self.bn1(self.conv1(x)))))
        out = F.relu(out + self.extra(x))
        return out


# 构建网络模型
class Restnet18(nn.Module):
    def __init__(self):
        super(Restnet18, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=3, padding=0),
            nn.BatchNorm2d(16)
        )
        self.blk1 = Resblk(16, 32, stride=3)
        self.blk2 = Resblk(32, 64, stride=3)
        self.blk3 = Resblk(64, 128, stride=2)
        self.blk4 = Resblk(128, 256, stride=2)
        self.linear = nn.Linear(256 * 3 * 3, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.blk1(x)
        x = self.blk2(x)
        x = self.blk3(x)
        x = self.blk4(x)
        x = x.view(x.size(0), -1)
        x = self.linear(x)
        return x


# data = torch.randn(2, 3, 224, 224)
# model = Restnet18()
# out = model(data)
# print(out.shape)

# 构建测试函数
def val(val_loader, device, model):
    model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.to(device), target.to(device)
            logit = model(data)
            pred = torch.argmax(logit, dim=1)
            correct += torch.eq(pred, target).sum().float().item()
        correct /= len(val_loader.dataset)  # 注意数据的长度是len(val_loader.dataset)不是len(data)
    return correct


# 构建训练函数
def train(train_loader, val_loader, epochs, device, model, criteon, optimizer, viz):
    best_acc, best_epoch = 0, 0
    global_step = 0  # 用于绘制图片的x坐标
    viz.line([0], [-1], win='loss', opts=dict(title='loss'))
    viz.line([0], [-1], win='val_acc', opts=dict(title='val_acc'))

    for epoch in range(epochs):
        train_correct = 0
        # 其中enumerate函数返回编号和数据两个内容,step表示编号,(data,target)表示数据和标识
        # 此时的step的数值=训练集的大小/batch的大小,例如训练集的step = 699/32 = 22
        for step, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            logit = model(data)
            loss = criteon(logit, target)

            train_pred = torch.argmax(logit, dim=1)
            # 根据输出,由于一个step训练27张图片,因此需要将每个step内与图片标识一致的数量求和之后标识正确率
            train_correct += torch.eq(train_pred, target).sum().float().item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            viz.line([loss.item()], [global_step], win='loss', update='append')
            global_step += 1
            
        train_correct /= len(train_loader.dataset)
        print('epoch:{}, train_acc:{}'.format(epoch, train_correct))

        # 调用val函数对于每次epoch计算模型的正确率,并覆盖保存正确率最高的模型
        if epoch % 1 == 0:
            val_acc = val(val_loader, device, model)
            print('epoch:{}, correct:{}'.format(epoch, val_acc))
            if (val_acc > best_acc):
                best_acc = val_acc
                best_epoch = epoch
                torch.save(model.state_dict(), 'data/best.mdl')
            viz.line([val_acc], [global_step], win='val_acc', update='append')

    return best_acc, best_epoch


def main():
    device = torch.device('cuda')
    learning_rate = 0.001
    epochs = 10

    # 加载数据
    path = 'data'
    batch_size = 32
    train_db = Pokemon(path, 224, mode='train')
    val_db = Pokemon(path, 224, mode='val')
    test_db = Pokemon(path, 224, mode='test')

    print(len(train_db))
    print(len(val_db))
    print(len(test_db))

    # 通过将数据按照batch_size大小划分,例如train_db总共699个数据,每一个batch_size是32,那么总共就有22个batch
    train_loader = DataLoader(train_db, batch_size=batch_size, num_workers=4, shuffle=True)
    val_loader = DataLoader(val_db, batch_size=batch_size, num_workers=2)
    test_loader = DataLoader(test_db, batch_size=batch_size, num_workers=2)

    # 加载模型
    model = Restnet18().to(device)
    criteon = nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # 加载visdom
    viz = Visdom()

    # 将训练集和验证集带入模型进行训练,并保存正确率最高的模型
    best_acc, best_epoch = train(train_loader, val_loader, epochs, device, model, criteon, optimizer, viz)
    print('best_epoch:{},best_acc:{}'.format(best_epoch, best_acc))

    # 加载正确率最高的模型,并将测试集数据带入验证正确率
    model.load_state_dict(torch.load('data/best.mdl'))
    print('loaded successful')
    test_correct = val(test_loader, device, model)
    print('test correct is {}'.format(test_correct))


if __name__ == '__main__':
    main()

下图是训练样本loss的下降曲线,基本处于下降趋势,但是后期浮动较大

 下图是测试集acc的曲线,发现训练三四个epoch之后正确率基本不变

最终模型导出的正确率大约在90%左右,效果相对可以。

数据集提取地址:

链接:https://pan.baidu.com/s/1lG32vBzQJeIrrx0lx8W1cQ 
提取码:m96e

Pokemon数据集是一个包含了关于Pokemon(宠物小精灵)的信息的数据集。这个数据集中收集了数百种Pokemon的属性、能力、技能、种族值等详细信息,可以用来进行各种数据分析和机器学习任务。 这个数据集中的属性信息包括每只Pokemon的种类、身高、重量、颜色等等。能力信息包括每只Pokemon的生命值、攻击力、防御力、速度等等。技能信息包括每只Pokemon可以使用的特殊技能、物理技能和状态技能。种族值则是一种用来表示Pokemon基础能力值的指标,能够影响Pokemon在战斗中的表现。 通过对Pokemon数据集进行分析,我们可以了解每个种类Pokemon的平均属性值、能力分布和技能种类等等。比如,我们可以分析哪些Pokemon的攻击力和速度高,哪些Pokemon的特殊防御力比较低,以及它们之间的关联性等。这对于创作游戏策略、进行角色平衡的调整等方面都有着重要的作用。 此外,Pokemon数据集还可以用于机器学习的任务。我们可以利用这些数据训练模型来预测Pokemon的属性、种族值等信息,或者构建一个可以根据Pokemon的属性和技能来推荐最佳战斗队伍的模型。这些模型可以在游戏中用于AI对战、自动战斗等功能。 总之,Pokemon数据集是一个提供了Pokemon相关信息的数据集,对于理解和分析Pokemon的属性、能力以及进行相关的机器学习任务具有重要意义。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值