pytorch重写 Dataset

PyTorch 通过继承 Dataset 来加载自定义数据集

首先介绍自己的 Mydataset

import os
import glob
import csv
import random

from PIL import Image
import torch
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader

class Mydataset(Dataset):
    """Image-classification dataset built from a class-per-subdirectory tree.

    ``root`` must contain one subdirectory per class.  The file list is
    cached in a CSV inside ``root`` and split 60/20/20 into
    train/val/test according to ``mode``.
    """

    def __init__(self, root, resize, mode):
        """
        Args:
            root: directory containing one subdirectory per class.
            resize: target square side length of the returned image tensor.
            mode: 'train' (first 60%), 'val' (60%-80%) or 'test' (last 20%).

        Raises:
            ValueError: if ``mode`` is not one of the three split names.
        """
        super(Mydataset, self).__init__()
        self.root = root
        self.resize = resize
        self.mode = mode

        # Map each class-directory name to a stable integer label (0, 1, 2, ...).
        # sorted() makes the mapping deterministic across runs.
        self.name2label = {}
        for name in sorted(os.listdir(root)):
            if not os.path.isdir(os.path.join(root, name)):
                continue
            self.name2label[name] = len(self.name2label)
        print(self.name2label)

        self.images, self.labels = self.load_csv('imagess.csv')

        n = len(self.images)
        if mode == 'train':  # 0% -> 60%
            self.images = self.images[:int(0.6 * n)]
            self.labels = self.labels[:int(0.6 * n)]
        elif mode == 'val':  # 60% -> 80%
            self.images = self.images[int(0.6 * n):int(0.8 * n)]
            self.labels = self.labels[int(0.6 * n):int(0.8 * n)]
        elif mode == 'test':  # 80% -> 100%
            self.images = self.images[int(0.8 * n):]
            self.labels = self.labels[int(0.8 * n):]
        else:
            # Fix: an unknown mode used to fall through silently to the
            # test split, hiding caller typos.
            raise ValueError("mode must be 'train', 'val' or 'test', got %r" % (mode,))

        # Fix: build the transform pipeline once instead of on every
        # __getitem__ call.  Keeping it lambda-free also keeps the dataset
        # picklable for DataLoader(num_workers>0) under the 'spawn' method.
        self.transform = transforms.Compose([
            transforms.Resize((int(self.resize * 1.25), int(self.resize * 1.25))),
            transforms.RandomRotation(15),
            transforms.CenterCrop(self.resize),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

    def load_csv(self, filename):
        """Return ``(image_paths, labels)``; build and cache the CSV on first use.

        The CSV lives at ``root/filename`` with one ``path,label`` row per
        image, written in a randomly shuffled order.
        """
        csv_path = os.path.join(self.root, filename)
        if not os.path.exists(csv_path):
            images = []
            for name in self.name2label.keys():
                images += glob.glob(os.path.join(self.root, name, '*.png'))
                images += glob.glob(os.path.join(self.root, name, '*.jpg'))
                images += glob.glob(os.path.join(self.root, name, '*.jpeg'))
            print(len(images), images)

            random.shuffle(images)
            with open(csv_path, mode='w', newline='') as f:
                writer = csv.writer(f)
                for img in images:
                    # The class name is the parent directory; dirname/basename
                    # is robust to mixed path separators, unlike split(os.sep).
                    name = os.path.basename(os.path.dirname(img))
                    label = self.name2label[name]
                    writer.writerow([img, label])
                print('writen into csv file:', filename)

        # Read the (possibly pre-existing) CSV back.
        images, labels = [], []
        with open(csv_path) as f:
            reader = csv.reader(f)
            for row in reader:
                img, label = row
                images.append(img)
                labels.append(int(label))

        assert len(images) == len(labels)
        return images, labels

    def __len__(self):
        """Number of samples in the selected split."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image_tensor, label_tensor)`` for sample ``idx``."""
        img_path, label = self.images[idx], self.labels[idx]
        img = Image.open(img_path).convert('RGB')
        img = self.transform(img)
        return img, torch.tensor(label)

    def denormalize(self, x_hat):
        """Undo ImageNet normalization: x = x_hat * std + mean.

        Works for a single image [3, H, W] or a batch [B, 3, H, W]
        (the [3, 1, 1] stats broadcast over the spatial dims).
        """
        mean = torch.tensor([0.485, 0.456, 0.406],
                            device=x_hat.device).unsqueeze(1).unsqueeze(1)
        std = torch.tensor([0.229, 0.224, 0.225],
                           device=x_hat.device).unsqueeze(1).unsqueeze(1)
        return x_hat * std + mean


def main():
    """Visual sanity check: stream ImageFolder batches to a visdom window.

    Requires a running visdom server (``python -m visdom.server``) and an
    image tree under ``./dataset`` with one subdirectory per class.
    """
    import visdom
    import time
    import torchvision

    viz = visdom.Visdom()

    tf = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
    ])

    dataset = torchvision.datasets.ImageFolder(root='dataset', transform=tf)
    batches = DataLoader(dataset, batch_size=32, shuffle=True)

    for x, y in batches:
        # One 8-wide image grid per batch, plus the raw labels as text.
        viz.images(x, nrow=8, win='batch', opts=dict(title='batch'))
        viz.text(str(y.numpy()), win='label', opts=dict(title='batch-y'))
        time.sleep(10)


if __name__ == "__main__":
    main()

基于 resnet18 如何加载数据训练,首先在 Flatten.py 中完成一个 Flatten 类(以及绘图辅助函数)

import torch
import torch.nn as nn

import matplotlib.pyplot as plt


class Flatten(nn.Module):
    """Collapse every dimension after the batch dimension into one.

    Maps a [B, d1, d2, ...] tensor to [B, d1*d2*...]; used to bridge a
    conv backbone's pooled output into a fully connected layer.
    """

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        # Keep the batch axis, let view() infer the flattened size.
        return x.view(x.size(0), -1)


def plot_image(img, label, name):
    """Show the first six images of a batch in a 2x3 grid.

    NOTE(review): the 0.3081/0.1307 un-normalization constants are the
    MNIST std/mean, and only channel 0 is drawn -- this assumes a
    single-channel, MNIST-normalized batch [B, 1, H, W]; confirm callers.
    """
    fig = plt.figure()
    for i in range(6):
        ax = fig.add_subplot(2, 3, i + 1)
        plt.tight_layout()
        ax.imshow(img[i][0] * 0.3081 + 0.1307, cmap='gray', interpolation='none')
        ax.set_title('{}: {}'.format(name, label[i].item()))
        ax.set_xticks([])
        ax.set_yticks([])
    plt.show()

完成 train_resnet18.py 训练程序

import torch
import visdom
import torch.nn as nn
import torch.optim
from mydataset import Mydataset
from torch.utils.data import Dataset, DataLoader

from Flatten import Flatten
from torchvision.models.resnet import resnet18

# Hyper-parameters.
batchsize = 32
learning_rate = 1e-5  # small LR: we are fine-tuning a pretrained backbone
epoches = 1
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# 60/20/20 train/val/test splits over the same image tree (split is done
# inside Mydataset based on `mode`).
# NOTE(review): root here is 'datasets' while the visdom demo uses
# 'dataset' -- confirm which directory actually holds the images.
train_db = Mydataset('datasets', 32, mode='train')
val_db = Mydataset('datasets', 32, mode='val')
test_db = Mydataset('datasets', 32, mode='test')


# Only the training loader shuffles; workers parallelize image decoding.
train_loader = DataLoader(train_db, batch_size=batchsize, shuffle=True, num_workers=4)
val_loader = DataLoader(val_db, batch_size=batchsize, num_workers=2)
test_loader = DataLoader(test_db, batch_size=batchsize, num_workers=2)

# Train the model (visdom client for live loss/accuracy curves).

viz = visdom.Visdom()


def evaluate(model, loader):
    """Return classification accuracy of ``model`` over ``loader``.

    Fix: the model is switched to eval mode first so that dropout and
    batch-norm layers behave deterministically during evaluation
    (the caller's training loop switches back with model.train()).
    """
    model.eval()
    correct = 0
    total = len(loader.dataset)
    with torch.no_grad():  # hoisted: no gradients anywhere in evaluation
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            pred = logits.argmax(dim=1)
            correct += torch.eq(pred, y).sum().float().item()
    # Guard against an empty split instead of raising ZeroDivisionError.
    return correct / total if total else 0.0


def main():
    """Fine-tune an ImageNet-pretrained resnet18 as a 2-class classifier.

    Tracks validation accuracy each epoch, checkpoints the best-performing
    weights, then reloads that checkpoint for the final test evaluation.
    """
    # Reuse every pretrained layer except the final fc, then map the
    # 512-d pooled feature to our 2 classes.
    model = resnet18(pretrained=True)
    model = nn.Sequential(*list(model.children())[:-1],  # [b, 512, 1, 1]
                          Flatten(),                     # [b, 512, 1, 1] -> [b, 512]
                          nn.Linear(512, 2)).to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    ckpt_path = 'resnet18-circle25-50.pkl'
    # best_acc starts below any real accuracy so the first epoch always
    # produces a checkpoint.
    best_acc, best_epoch = -1.0, 0
    global_step = 0
    viz.line([0], [-1], win='loss', opts=dict(title='loss'))
    viz.line([0], [-1], win='val_acc', opts=dict(title='val_acc'))

    for epoch in range(epoches):
        for step, (x, y) in enumerate(train_loader):
            viz.images(train_db.denormalize(x), nrow=8, win='batch', opts=dict(title='batch'))
            viz.text(str(y.numpy()), win='label', opts=dict(title='batch-y'))
            x, y = x.to(device), y.to(device)

            model.train()  # back to train mode (evaluate() switches to eval)
            logits = model(x)
            loss = criterion(logits, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            viz.line([loss.item()], [global_step], win='loss', update='append')
            global_step += 1

        # Validate every epoch; plot every point, checkpoint on improvement.
        val_acc = evaluate(model, val_loader)
        viz.line([val_acc], [global_step], win='val_acc', update='append')
        if val_acc > best_acc:
            best_epoch = epoch
            best_acc = val_acc
            # Fix: the original saved only the *final* weights after the
            # loop -- save the best-validation weights as they appear.
            torch.save(model.state_dict(), ckpt_path)

    print("best acc:", best_acc, "best epoch:", best_epoch)

    # Fix: actually load the best checkpoint before testing (the original
    # printed "loaded from ckpt!" without ever loading anything).
    model.load_state_dict(torch.load(ckpt_path))
    print("loaded from ckpt!")
    test_acc = evaluate(model, test_loader)
    print("test acc:", test_acc)


if __name__ == "__main__":
    main()

使用 visdom 进行可视化,完成物体的识别.

  • 3
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值