Pytorch学习笔记7——自定义数据集

Pytorch学习笔记7——自定义数据集

在这里插入图片描述
在这里插入图片描述

1.读取数据

首先继承自torch.utils.data.Dataset
重写len与getitem
在这里插入图片描述
train就用train数据集,test就用test数据集。

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

自定义数据集的读取

import torch
import os,glob
import random,csv

import torchvision.datasets
from torch.utils.data import Dataset,DataLoader
from torchvision import transforms
from PIL import Image
from torchvision.transforms import InterpolationMode


class Pokemon(Dataset):
    def __init__(self,root,resize,mode):
        super(Pokemon, self).__init__()

        self.root=root
        self.resize=resize

        self.name2label={}#字典表达映射关系label‘‘sq..’’:0
        for name in sorted(os.listdir(os.path.join(root))):#遍历根目录下所有文件假
            if not os.path.isdir(os.path.join(root,name)):#判断是否是文件夹
                continue
            self.name2label[name]=len(self.name2label.keys())
        print(self.name2label)
        self.images,self.labels=self.load_csv('images.csv')#得到的是images的路径,和对应的数字标签

        if mode=='train':
            self.images=self.images[:int(0.6*len(self.images))]
            self.labels=self.labels[:int(0.6*len(self.labels))]
        elif mode=='val':#20
            self.images=self.images[int(0.6*len(self.images)):int(0.8*len(self.images))]
            self.labels=self.labels[int(0.6*len(self.labels)):int(0.8*len(self.labels))]
        else:
            self.images=self.images[int(0.8*len(self.images)):]
            self.labels=self.labels[int(0.8*len(self.labels)):]

#创建数据对:path+label
    def load_csv(self,filename):
        if not os.path.exists(os.path.join(self.root,filename)):#如果已经有了,不需要再创建
            images=[]
            for name in self.name2label.keys():#key:value
                images+=glob.glob(os.path.join(self.root,name,'*.png'))#glob方法获取目录下所有满足的文件
                images += glob.glob(os.path.join(self.root, name, '*.jpg'))
                images += glob.glob(os.path.join(self.root, name, '*.jpeg'))
            #1165,pokeman/bulbasaur/00001.png

            #对应关系保存到csv
            random.shuffle(images)
            with open(os.path.join(self.root,filename),mode='w',newline='') as f:
                writer=csv.writer(f)
                for img in images:#pokeman/bulbasaur/00001.png
                    name=img.split(os.sep)[-2]
                    label=self.name2label[name]#字典根据key找value存入label
                    writer.writerow([img,label])#pokeman/bulbasaur/00001.png,0
                print('writen into csvfile:',filename)
        #read from csv
        images,labels=[],[]
        with open(os.path.join(self.root,filename)) as f:
            reader=csv.reader(f)
            for row in reader:
                img,label=row
                label=int(label)
                images.append(img)
                labels.append(label)
        assert len(images)==len(labels)
        return images,labels
    def __len__(self):
        return len(self.images)

    def denormalize(self,x_hat):#逆归一化已回复图片视觉效果
        mean = [0.845, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]

        #x_hat=(x-mean)/std
        #x=x_hat*std+mean
        #x:[c,h,w]
        #mean:[3]=>[3,1,1]
        mean=torch.tensor(mean).unsqueeze(1).unsqueeze(1)
        std = torch.tensor(std).unsqueeze(1).unsqueeze(1)
        print(mean.shape,std.shape)
        x=x_hat*std+mean
        return x

    def __getitem__(self, idx):
        #self.images,self.labels
        #idx-[0-len(images)]
        img,label=self.images[idx],self.labels[idx]#从csv获得的图片路径与label
        tf=transforms.Compose([
            lambda x:Image.open(x).convert('RGB'),#string path=>image data
            transforms.Resize((int(self.resize*1.25),int(self.resize*1.25))),
            transforms.RandomRotation(15),
            transforms.CenterCrop(self.resize),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.845,0.456,0.406],
                                 std=[0.229,0.224,0.225])
        ])


        img=tf(img)
        label=torch.tensor(label)
        return img,label
if __name__=='__main__':
    import visdom
    import time
    viz=visdom.Visdom()
    # tf = transforms.Compose([
    #     transforms.Resize((64,64)),
    #     transforms.ToTensor(),
    # ])
    # db=torchvision.datasets.ImageFolder(root='/home/lizheng/Study/yolov5-5.0/pytorch1/pokemon/pokeman',transform=tf)
    # loader=DataLoader(db,batch_size=32,shuffle=True)#一行代码完成数据集加载工作
    # print(db.class_to_idx)
    # for x,y in loader:
    #     viz.images(x,nrow=8,win='batch',opts=dict(title='batch'))
    #     viz.text(str(y.numpy()),win='label',opts=dict(title='batch-y'))
    #
    #     time.sleep(10)
    db=Pokemon('pokemon/pokeman',64,'train')

    x,y=next(iter(db))#利用迭代器输入路径获得具体图像,得到第一个样本,调用时自动使用getitem函数,此时x是图像
    print('sample:',x.shape,y.shape,y)

    viz.image(db.denormalize(x),win='sample_x',opts=dict(title='sample_x'))
    loader=DataLoader(db,batch_size=32,shuffle=True,num_workers=8)#不想一个一个取,想一个batch一个batch取
    for x,y in loader:
        viz.images(db.denormalize(x),nrow=8,win='batch',opts=dict(title='batch'))
        viz.text(str(y.numpy()),win='label',opts=dict(title='batch-y'))

        time.sleep(10)


自定义神经网络模型的搭建

import torch
from torch import nn
from torch.nn import functional as F
class ResBlk(nn.Module):
    def __init__(self,ch_in,ch_out,stride=1):
        '''
        :param ch_in:
        :param ch_out:
        '''
        super(ResBlk,self).__init__()#super方法避免父类的init函数被替换
        self.conv1=nn.Conv2d(ch_in,ch_out,kernel_size=3,stride=stride,padding=1)
        self.bn1=nn.BatchNorm2d(ch_out)
        self.conv2=nn.Conv2d(ch_out,ch_out,kernel_size=3,stride=1,padding=1)
        self.bn2=nn.BatchNorm2d(ch_out)

        self.extra=nn.Sequential()

        if ch_out!=ch_in:
            self.extra=nn.Sequential(
                nn.Conv2d(ch_in,ch_out,kernel_size=1,stride=stride),#Sequential里面加入的都是类,因此可以自己写,然后加入
                nn.BatchNorm2d(ch_out)#这些类在调用时会自动调用forward函数,记得要写return
            )


    def forward(self,x):
        '''
        :param x:[b,ch,h,w]
        :return:
        '''
        out=F.relu(self.bn1(self.conv1(x)))
        out=self.bn2(self.conv2(out))
        #short cut
        #extra module:[b,ch_in,h,w] with [b,ch_out,h,w]
        #element-wise add:\
        out=self.extra(x)+out
        return out

class ResNet18(nn.Module):
    def __init__(self,num_class):
        super(ResNet18,self).__init__()
        self.conv1=nn.Sequential(
            nn.Conv2d(3,16,kernel_size=3,stride=3,padding=0),
            nn.BatchNorm2d(16)
        )
        #followed 4 blocks
        #[b,16,h,w]=>[b,32,h,w]
        self.blk1=ResBlk(16,32,stride=3)#增多通道,减少长宽,避免数据量过大
        #[b,32,h,w]=>[b,64,h,w]
        self.blk2 = ResBlk(32, 64,stride=3)
        #[b,64,h,w]=>[b,128,h,w]
        self.blk3 = ResBlk(64,128,stride=2)
        # [b,128,h,w]=>[b,256,h,w]
        self.blk4 = ResBlk(128,256,stride=2)
        #[b,256,7,7]
        self.outlayer=nn.Linear(256*3*3,num_class)#输入512通道,输出10通道
    def forward(self,x):
        '''
        :param x:
        :return:
        '''
        x=F.relu(self.conv1(x))
        #[b,64,h,w]=>[b,1024,h,w]
        x=self.blk1(x)
        x=self.blk2(x)
        x=self.blk3(x)
        x=self.blk4(x)
       # print('after conv:',x.shape)#[b,512,2,2]
     #    # [b,512,h,w]=>[b,512,2,2]
     #    x=F.adaptive_avg_pool2d(x,[1,1])
     # #   print('after pool:',x.shape)
        x=x.view(x.size(0),-1)
        x=self.outlayer(x)
        return x

if __name__=='__main__':
    blk=ResBlk(64,128)
    tmp=torch.randn(2,64,224,224)
    out=blk(tmp)
    print('block',out.shape)

    model=ResNet18(5)#5分类
    tmp=torch.randn(2,3,224,224)
    out=model(tmp)
    print('resnet:',out.shape)

    p=sum(map(lambda p:p.numel(),model.parameters()))
    print('parameters size:',p)

自定义数据集的训练与测试:

import torch
from torch import optim,nn
import visdom
import torchvision
from torch.utils.data import DataLoader

from pytorch1.pt3 import Pokemon
from resnet import ResNet18

batchsz=32
lr=1e-3
epochs=10

device=torch.device('cuda')
torch.manual_seed(1234)

train_db=Pokemon('pokemon/pokeman',224,mode='train')#initial函数初始化训练集,
val_db=Pokemon('pokemon/pokeman',224,mode='val')
test_db=Pokemon('pokemon/pokeman',224,mode='test')

train_loader=DataLoader(train_db,batch_size=batchsz,shuffle=True,num_workers=4)#loader里获得的都是真正是图片
val_loader=DataLoader(val_db,batch_size=batchsz,num_workers=4)
test_loader=DataLoader(test_db,batch_size=batchsz,num_workers=4)

viz = visdom.Visdom()

def evaluate(model,loader):
    correct=0
    total=len(loader.dataset)
    for x,y in loader:
        x,y=x.to(device),y.to(device)
        with torch.no_grad():
            logits=model(x)
            pred=logits.argmax(dim=1)
        correct+=torch.eq(pred,y).sum().float().item()

    return correct/total



if __name__=='__main__':
    model=ResNet18(5).to(device)
    optimizer=optim.Adam(model.parameters(),lr=lr)
    criteon=nn.CrossEntropyLoss()
    best_acc,best_epoch=0,0
    global_step=0
    viz.line([0],[-1],win='loss',opts=dict(title='loss'))
    viz.line([0],[-1],win='val_acc',opts=dict(title='val_acc'))#清空操作

    for epoch in range(epochs):
        for step,(x,y) in enumerate(train_loader):
            #x:[b,3,224,224],y:[b]
            x,y=x.to(device),y.to(device)
            logits=model(x)
            loss=criteon(logits,y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            viz.line([loss.item()], [global_step], win='loss', update='append')
            global_step+=1
        if epoch%2==0:
            val_acc=evaluate(model,val_loader)
            if val_acc>best_acc:
                best_epoch=epoch
                best_acc=val_acc

                torch.save(model.state_dict(),'best.mdl')
                viz.line([val_acc], [global_step], win='val_acc',update='append')

    print('best acc:',best_acc,'best_epoch:',best_epoch)
    model.load_state_dict(torch.load('best.mdl'))
    print('loaded from ckpt!')

    test_acc=evaluate(model,test_loader)
    print('test acc:',test_acc)

实验效果:

在这里插入图片描述

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值