深度学习-少量自定义数据集实现迁移学习

本文基于预训练的 ResNet-18 网络,使用少量 Pokemon 数据集实现迁移学习,并在此过程中使用 visdom 对数据集和训练过程进行可视化。本文代码主要分为两部分:1. 加载自定义数据集(数据预处理,为各类别定义标签);2. 迁移学习。第二部分的训练代码依赖第一部分定义的 Pokemon 数据集类。

import torch
import torch.nn as nn
from torch.nn import functional as F
import os,glob
import random,csv
from torch.utils.data import Dataset,DataLoader
from torchvision import transforms
from PIL import Image

class Pokemon(Dataset):
    """Image-folder dataset where each sub-directory of ``root`` is one class.

    The list of (image path, label) pairs is cached in ``root/images.csv`` the
    first time the dataset is built, then split by position: the first 60% is
    'train', the next 20% is 'val', and the remaining 20% is used for any other
    mode (e.g. 'test').

    Args:
        root: dataset root directory containing one sub-directory per class.
        resize: side length (pixels) of the square images returned.
        mode: 'train', 'val', or anything else for the test split.
    """

    # Channel statistics used both by Normalize and by denormalize().
    MEAN = [0.485, 0.456, 0.406]
    STD = [0.229, 0.224, 0.225]

    def __init__(self, root, resize, mode):
        super(Pokemon, self).__init__()
        self.root = root
        self.resize = resize
        # Class name -> integer label, assigned in sorted directory order so
        # the mapping is stable across runs.
        self.name2label = {}
        for name in sorted(os.listdir(root)):
            if not os.path.isdir(os.path.join(root, name)):
                continue
            self.name2label[name] = len(self.name2label)
        print(self.name2label)

        images, labels = self.load_csv('images.csv')
        # Slice the full lists (not an already-sliced one) for the split.
        self.images = self._split_for_mode(images, mode)
        self.labels = self._split_for_mode(labels, mode)

        # Built once instead of on every __getitem__ call. Image loading stays in
        # __getitem__, so this object holds no lambda and stays picklable for
        # DataLoader worker processes.
        # NOTE(review): RandomRotation is also applied to the val/test splits,
        # which makes evaluation slightly nondeterministic; consider limiting
        # it to mode == 'train'.
        self.transform = transforms.Compose([
            transforms.Resize((int(resize * 1.25), int(resize * 1.25))),
            transforms.RandomRotation(15),
            transforms.CenterCrop(resize),
            transforms.ToTensor(),
            transforms.Normalize(mean=self.MEAN, std=self.STD),
        ])

    @staticmethod
    def _split_for_mode(items, mode):
        """Return the part of *items* for *mode* (60% / 20% / 20% split).

        Any mode other than 'train' or 'val' selects the final 20% (test).
        """
        n = len(items)
        if mode == 'train':
            return items[:int(0.6 * n)]
        elif mode == 'val':
            return items[int(0.6 * n):int(0.8 * n)]
        return items[int(0.8 * n):]

    def load_csv(self, filename):
        """Load (or first create) the ``root/filename`` index of images.

        If the CSV does not exist, every .png/.jpg/.jpeg file in each class
        directory is collected, shuffled and written as ``path,label`` rows.
        The CSV is then read back.

        Returns:
            (images, labels): parallel lists of path strings and int labels.
        """
        path = os.path.join(self.root, filename)
        if not os.path.exists(path):
            images = []
            for name in self.name2label:
                for pattern in ('*.png', '*.jpg', '*.jpeg'):
                    images += glob.glob(os.path.join(self.root, name, pattern))

            print(len(images), images)
            random.shuffle(images)
            with open(path, mode='w', newline='') as f:
                writer = csv.writer(f)
                for img in images:  # e.g. pokemon/bulbasaur/00000000.png
                    # The parent directory name is the class name.
                    name = img.split(os.sep)[-2]
                    writer.writerow([img, self.name2label[name]])
            print('writen into csv file:', filename)

        images, labels = [], []
        # newline='' is what the csv module expects for files it reads.
        with open(path, newline='') as f:
            for img, label in csv.reader(f):
                images.append(img)
                labels.append(int(label))
        assert len(images) == len(labels)
        return images, labels

    def __len__(self):
        return len(self.images)

    def denormalize(self, x_hat):
        """Undo Normalize: ``x = x_hat * std + mean``.

        mean/std are reshaped to [3, 1, 1] so they broadcast over a
        [c, h, w] image or a [b, c, h, w] batch.
        """
        mean = torch.tensor(self.MEAN).unsqueeze(1).unsqueeze(1)
        std = torch.tensor(self.STD).unsqueeze(1).unsqueeze(1)
        return x_hat * std + mean

    def __getitem__(self, idx):
        """Return ``(image_tensor, label_tensor)`` for sample *idx*."""
        path, label = self.images[idx], self.labels[idx]
        # Use a context manager so the underlying file handle is closed.
        with Image.open(path) as im:
            img = im.convert('RGB')
        return self.transform(img), torch.tensor(label)
        
    
def main():
    """Visual sanity check of the dataset: show one sample, then batches, in visdom."""
    import visdom
    import time

    viz = visdom.Visdom()
    db = Pokemon('pokemon', 224, 'train')

    # A single sample first, to check shapes and the label value.
    sample_x, sample_y = next(iter(db))
    print('sample', sample_x.shape, sample_y.shape, sample_y)
    viz.image(db.denormalize(sample_x), win='sample_x', opts=dict(title='sample_x'))

    # Then whole shuffled batches, pausing so each one can be inspected.
    loader = DataLoader(db, batch_size=32, shuffle=True, num_workers=0)
    for batch_x, batch_y in loader:
        viz.images(db.denormalize(batch_x), nrow=8, win='batch',
                   opts=dict(title='batch'))
        viz.text(str(batch_y.numpy()), win='label', opts=dict(title='batch-y'))
        time.sleep(10)


if __name__ == '__main__':
    main()
#旧版本 PyTorch 没有 Flatten 层(新版本已提供 nn.Flatten),因此这里手写一个等价的 Flatten 层
class Flatten(nn.Module):
    """Flatten every dimension except the leading batch dimension.

    [b, d1, d2, ...] -> [b, d1 * d2 * ...]; equivalent to ``nn.Flatten()``.
    """

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        # torch.flatten copies only when necessary, so unlike x.view it also
        # accepts non-contiguous inputs (e.g. after transpose/permute).
        return torch.flatten(x, 1)
#迁移学习
from torch import optim,nn
import visdom
import torchvision
from torch.utils.data import DataLoader
from torchvision.models import resnet18

# Hyper-parameters and runtime configuration.
batchsz, lr, epochs = 32, 1e-3, 10
device = torch.device('cpu')
torch.manual_seed(1234)

# One dataset object per split, all built from the same image index.
train_db, val_db, test_db = [Pokemon('pokemon', 224, split)
                             for split in ('train', 'val', 'test')]
# Only the training loader shuffles; evaluation order does not matter.
train_loader = DataLoader(train_db, batch_size=batchsz, shuffle=True)
val_loader, test_loader = [DataLoader(ds, batch_size=batchsz)
                           for ds in (val_db, test_db)]
viz = visdom.Visdom()


def evaluate(model, loader):
    """Return the top-1 accuracy of *model* over every sample in *loader*.

    The model is switched to eval mode for this pass, so Dropout is disabled
    and BatchNorm uses its running statistics. Its previous train/eval mode
    is restored afterwards, even if an exception occurs. Batches are moved
    to the module-level ``device``.

    Args:
        model: classifier returning logits of shape [b, num_classes].
        loader: DataLoader yielding (x, y) batches; ``len(loader.dataset)``
            is used as the denominator.

    Returns:
        Accuracy as a float in [0, 1]; 0.0 for an empty dataset.
    """
    total = len(loader.dataset)
    if total == 0:
        return 0.0
    was_training = model.training
    model.eval()
    correct = 0
    try:
        with torch.no_grad():
            for x, y in loader:
                x, y = x.to(device), y.to(device)
                pred = model(x).argmax(dim=1)
                correct += torch.eq(pred, y).sum().item()
    finally:
        model.train(was_training)
    return correct / total


def main():
    """Fine-tune an ImageNet-pretrained ResNet-18 on the 5-class dataset.

    Trains for ``epochs`` epochs and evaluates on the validation set after
    each epoch. The best model is checkpointed to ``best.mdl``; that
    checkpoint is then reloaded and test accuracy is reported. Training loss
    and validation accuracy are streamed to visdom.
    """
    # Keep every pretrained child except the final fc layer, then attach a
    # new 512 -> 5 classification head.
    # NOTE: newer torchvision prefers resnet18(weights=...) over pretrained=True.
    trained_model = resnet18(pretrained=True)
    model = nn.Sequential(*list(trained_model.children())[:-1],
                          Flatten(),
                          nn.Linear(512, 5),
                          ).to(device)

    optimizer = optim.Adam(model.parameters(), lr=lr)
    # Takes raw logits and integer class indices (no one-hot encoding needed).
    criterion = nn.CrossEntropyLoss()
    # Start below any real accuracy so the first validation pass always
    # writes best.mdl; otherwise torch.load below could find no checkpoint.
    best_acc, best_epoch = -1.0, 0
    global_step = 0
    viz.line([0], [-1], win='loss', opts=dict(title='loss'))
    viz.line([0], [-1], win='val_acc', opts=dict(title='val_acc'))
    for epoch in range(epochs):
        # Evaluation may leave the model in eval mode; switch back so
        # BatchNorm/Dropout behave as in training.
        model.train()
        for x, y in train_loader:
            # x: [b, 3, 224, 224], y: [b]
            x, y = x.to(device), y.to(device)
            logits = model(x)
            loss = criterion(logits, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            viz.line([loss.item()], [global_step], win='loss', update='append')
            global_step += 1

        val_acc = evaluate(model, val_loader)
        if val_acc > best_acc:
            best_epoch = epoch
            best_acc = val_acc
            torch.save(model.state_dict(), 'best.mdl')
        # Plot every epoch's accuracy, not only the improvements.
        viz.line([val_acc], [global_step], win='val_acc', update='append')

    print('best acc:', best_acc, 'best epoch:', best_epoch)
    model.load_state_dict(torch.load('best.mdl', map_location=device))
    print('loaded from ckpt!')
    test_acc = evaluate(model, test_loader)
    print('test_acc:', test_acc)


if __name__ == '__main__':
    main()
	
		
  • 1
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 4
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值