PyTorch 训练自定义数据集

有点乱,先放在这里,主要做个记录。

特征提取

参考了两种网络结构(一开始用的是 VGG,后来怀疑是网络本身有问题,又加了一个 AlexNet)。

VGG13

VGG13 网络,把最后两个全连接隐藏层的宽度由原版的 4096 缩小为 1024。

import torch.nn as nn

class MyVGG13(nn.Module):
    """VGG13-style image classifier with a slimmed-down fully connected head.

    The convolutional trunk has ten 3x3 convolutions (stride 1, padding 1),
    each followed by ReLU, grouped into five stages that each end with a
    2x2/stride-2 max-pool. The classifier uses two 1024-unit hidden layers
    with Dropout(0.5).

    Args:
        numclasses: number of output logits (default 200).
    """

    # Output channels of each conv layer in order; 'M' marks a 2x2 max-pool.
    _CFG = (64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M')

    def __init__(self, numclasses=200):
        super(MyVGG13, self).__init__()

        # Layers are appended in the same order as a hand-written list would
        # have them (conv, relu, conv, relu, pool, ...), so the
        # ``features.<index>`` state_dict keys match previously saved checkpoints.
        layers = []
        in_channels = 3
        for v in self._CFG:
            if v == 'M':
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            else:
                layers.append(nn.Conv2d(in_channels, v, kernel_size=3, stride=1, padding=1))
                layers.append(nn.ReLU(inplace=True))
                in_channels = v
        self.features = nn.Sequential(*layers)

        # Five 2x2 pools turn a 224x224 input into a 7x7 map, so this pooling is
        # a no-op at that size; for other input sizes it brings the map to 7x7
        # so the 512*7*7 classifier still fits. It has no parameters, so the
        # state_dict keys are unchanged.
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))

        # 'classifer' (sic) is kept on purpose: the attribute name is part of
        # the saved state_dict keys.
        self.classifer = nn.Sequential(
            nn.Linear(512 * 7 * 7, 1024),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(1024, 1024),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(1024, numclasses),
        )

    def forward(self, x):
        """Return logits of shape (N, numclasses) for an (N, 3, H, W) image batch."""
        x = self.features(x)
        x = self.avgpool(x)
        x = x.flatten(1)  # (N, 512, 7, 7) -> (N, 512*7*7)
        return self.classifer(x)

AlexNet

AlexNet:输入尺寸为 224×224(原论文为 227×227)。此外,第二个卷积层的步长改成了 2,全连接层缩小为 2048 维,其余结构基本一致。

import torch.nn as nn
import torch
from torchsummary import summary

# Preferred compute device (first CUDA GPU if available); not used by the class itself.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


class MyAlexNet(nn.Module):
    """AlexNet-style image classifier sized for 224x224 RGB inputs.

    For a 224x224 input the spatial size goes 224 -> 55 (conv1, stride 4)
    -> 27 (pool) -> 13 (conv2, stride 2) -> 6 (pool) -> 6 (conv3-5) -> 2 (pool),
    which yields the 256*2*2 features the classifier expects.

    Args:
        numclasses: number of output logits (default 200).
    """

    def __init__(self, numclasses=200):
        super(MyAlexNet, self).__init__()

        self.features = nn.Sequential(
            nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(96, 256, 5, 2, 1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(256, 384, 3, 1, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 384, 3, 1, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, 3, 1, 1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        # The map is already 2x2 for a 224x224 input, so this is a no-op there.
        # For other (large enough) input sizes it resamples to 2x2 so the
        # classifier still fits. It has no parameters, so the state_dict keys
        # are unchanged.
        self.avgpool = nn.AdaptiveAvgPool2d((2, 2))
        # 'classifer' (sic) is kept: the name is part of the saved state_dict keys.
        self.classifer = nn.Sequential(
            nn.Linear(256 * 2 * 2, 2048),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(2048, 2048),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(2048, numclasses),
        )

    def forward(self, x):
        """Return logits of shape (N, numclasses) for an (N, 3, H, W) image batch."""
        x = self.features(x)
        x = self.avgpool(x)
        x = x.flatten(1)  # (N, 256, 2, 2) -> (N, 1024)
        return self.classifer(x)


整体

自己动手实现了一个网络,对一些知识有了更深的理解,收获不小。
关于学习率(learning rate),算是有了一点理解:真的是太大不行,太小也不行。一开始我设置的学习率不合适,不管训练多少轮,准确率都一直停在一个定值上。这让我误以为是数据集处理有问题,反复检查了多次也没找到原因。

参考

调参
代码整体结构比较好,要学着改自己的网络。

import argparse
import os
import glob
import csv
import PIL
import visdom
import matplotlib.pyplot as plt
import torch
import time

import torch.nn as nn

from train import MyAlexNet
from myVGG13 import MyVGG13
from torch.utils.data import Dataset
from torchvision import transforms
from torchsummary import summary

# Seed PyTorch's global RNG (used e.g. by weight initialisation and DataLoader
# shuffling) so repeated runs start from the same state.
torch.manual_seed(12345)

# Use the first CUDA GPU when one is visible; otherwise run on the CPU.
device = torch.device('cuda', 0) if torch.cuda.is_available() else torch.device('cpu')

class MyData(Dataset):
    """Image-folder dataset with one sub-directory of ``root`` per class.

    Integer labels are assigned in sorted sub-directory order, e.g.
    {'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflowers': 3, 'tulips': 4}.
    The (image path, label) list is cached in ``Image2Label.csv`` inside
    ``root`` on first use and read back from that file afterwards.

    Args:
        root: dataset directory containing one folder per class.
        transform: optional callable applied to each RGB PIL image.
    """

    def __init__(self, root, transform=None):
        super(MyData, self).__init__()
        self.root = root
        self.transform = transform

        # class folder name -> integer label; plain files in root are ignored
        self.name2label = {}
        for name in sorted(os.listdir(root)):
            if os.path.isdir(os.path.join(root, name)):
                self.name2label[name] = len(self.name2label)
        print(self.name2label)

        # parallel lists: image path and its integer label
        self.images, self.labels = self.load_csv('Image2Label.csv')

    def load_csv(self, fliename):
        """Return ``(images, labels)``, creating the CSV cache first if it is missing.

        Args:
            fliename: name of the CSV cache file, relative to ``self.root``.

        Returns:
            Two equally long lists: image paths (str) and integer labels.

        NOTE: an existing CSV is never rebuilt; delete it after adding or
        removing images or class folders.
        """
        csv_path = os.path.join(self.root, fliename)
        if not os.path.exists(csv_path):
            rows = []
            for name, label in self.name2label.items():
                # glob order depends on the filesystem; sort for a reproducible cache
                for img in sorted(glob.glob(os.path.join(self.root, name, '*.jpg'))):
                    rows.append((img, label))

            print(len(rows), [img for img, _ in rows])
            with open(csv_path, mode='w', newline='') as f:
                csv.writer(f).writerows(rows)
                print('writen into csv file:', fliename)

        images, labels = [], []
        # the csv module expects files opened with newline='' for reading too
        with open(csv_path, newline='') as f:
            for img, label in csv.reader(f):
                images.append(img)
                labels.append(int(label))

        assert len(images) == len(labels)
        return images, labels

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is loaded as RGB and passed through ``self.transform`` if one
        was given; the label is the integer class index.
        """
        # A bare ``import PIL`` does not load the ``PIL.Image`` submodule, so
        # import it explicitly instead of relying on another module doing so.
        from PIL import Image

        path, label = self.images[idx], self.labels[idx]
        # convert() returns a new, fully loaded image, so the file can be
        # closed as soon as the with-block ends.
        with Image.open(path) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, label

def evalute(net, Dataloader, device=None):
    """Return the top-1 accuracy of ``net`` over every sample in ``Dataloader``.

    The network is put in eval mode for the pass, so Dropout is disabled. It
    is switched back to whatever mode it was in before, even if an error
    occurs.

    Args:
        net: classifier producing logits of shape (batch, num_classes).
        Dataloader: yields ``(inputs, targets)`` batches, where targets are
            integer class indices; ``len(Dataloader.dataset)`` is the sample count.
        device: where each batch is moved. Defaults to the device of the
            network's first parameter (CPU if it has no parameters).

    Returns:
        Fraction of correctly classified samples, a float in [0, 1]; 0.0 when
        the dataset is empty.
    """
    total = len(Dataloader.dataset)
    if total == 0:
        return 0.0
    if device is None:
        first_param = next(net.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    was_training = net.training
    net.eval()  # disable Dropout (and similar layers) while measuring accuracy
    correct = 0
    try:
        with torch.no_grad():
            for x, y in Dataloader:
                x, y = x.to(device), y.to(device)
                pred = net(x).argmax(dim=1)
                correct += torch.eq(pred, y).sum().item()
    finally:
        net.train(was_training)
    return correct / total




def main():
    """Train MyVGG13 on a folder-per-class image dataset and keep the best weights.

    All settings come from the command line; the defaults reproduce the
    original flower_photos setup. After every epoch the model is scored on the
    training and validation sets. Whenever validation accuracy improves, the
    weights are saved to ``best.pth`` in the current working directory.
    """
    parser = argparse.ArgumentParser(description='训练参数')
    parser.add_argument('--batchsize', type=int, default=32, help='The number of batch_size')
    parser.add_argument('--epochs', type=int, default=10, help='The number of epochs')
    parser.add_argument('--lr', type=float, default=0.0004, help='Adam learning rate')
    parser.add_argument('--numclasses', type=int, default=None,
                        help='Number of output classes (default: number of class folders in --train-dir)')
    parser.add_argument('--train-dir', default=r'D:\Projects\DeepLearning\Dataset\flower_photos\train',
                        help='Training images, one sub-directory per class')
    parser.add_argument('--val-dir', default=r'D:\Projects\DeepLearning\Dataset\flower_photos\val',
                        help='Validation images, with the same class sub-directories as --train-dir')
    parser.add_argument('--num-workers', type=int, default=4, help='DataLoader worker processes')
    args = parser.parse_args()

    # The prediction script applies the same resize/normalisation; keep them in sync.
    tf = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    train_data = MyData(args.train_dir, tf)
    val_data = MyData(args.val_dir, tf)

    # Labels are derived from sorted folder names per split, so both splits
    # must contain exactly the same class folders or the labels would disagree.
    if val_data.name2label != train_data.name2label:
        raise ValueError('train/val class folders differ: {} vs {}'.format(
            sorted(train_data.name2label), sorted(val_data.name2label)))
    num_classes = len(train_data.name2label) if args.numclasses is None else args.numclasses

    train_loader = torch.utils.data.DataLoader(
        train_data, batch_size=args.batchsize, shuffle=True, num_workers=args.num_workers)
    # Accuracy does not depend on sample order, so the validation set is not shuffled.
    val_loader = torch.utils.data.DataLoader(
        val_data, batch_size=args.batchsize, shuffle=False, num_workers=args.num_workers)

    net = MyVGG13(num_classes).to(device)

    # torchsummary's summary() defaults to device='cuda', which fails on
    # CPU-only machines; pass the type of the device actually in use.
    summary(net, (3, 224, 224), device=device.type)

    optimizer = torch.optim.Adam(net.parameters(), lr=args.lr)
    criterion = nn.CrossEntropyLoss()

    best_acc, best_epoch = 0, 0
    for epoch in range(args.epochs):
        net.train()
        for index, (x, y) in enumerate(train_loader):
            x, y = x.to(device), y.to(device)
            loss = criterion(net(x), y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if index % 32 == 0:
                print('此时的epoch为:{}/{},训练的loss: {:.6f}'.format(epoch, args.epochs, loss.item()))

        # Score with Dropout disabled; net.train() at the top of the next
        # epoch switches training mode back on.
        net.eval()
        train_acc = evalute(net, train_loader)
        val_acc = evalute(net, val_loader)

        # Each label is printed next to its own metric.
        print('epoch:{}[{}/{}]\ttrain_acc:{:.8f}'.format(epoch, epoch, args.epochs, train_acc))
        print('epoch:{}[{}/{}]\tval_acc:{:.8f}'.format(epoch, epoch, args.epochs, val_acc))

        if val_acc > best_acc:
            best_epoch = epoch
            best_acc = val_acc
            torch.save(net.state_dict(), 'best.pth')

    print('best_acc:', best_acc, 'best_epoch:', best_epoch)


if __name__ == '__main__':
    main()

预测

参考

from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import torch

from train import MyAlexNet
from myVGG13 import MyVGG13

# Must match the preprocessing used during training (resize to 224x224,
# scale each channel to [-1, 1]).
data_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

Image_path = r'./4.jpg'
# convert('RGB') mirrors the training dataset and guarantees the 3 channels
# that Normalize and the network's first conv expect, even for grayscale,
# palette or RGBA files.
img = Image.open(Image_path).convert('RGB')
plt.imshow(img)

img = data_transform(img)
img = torch.unsqueeze(img, dim=0)  # add a batch dimension -> (1, 3, 224, 224)

net = MyVGG13(5)
# map_location='cpu' lets a checkpoint saved during a GPU run load on a CPU-only machine.
net.load_state_dict(torch.load('best.pth', map_location='cpu'))
net.eval()
with torch.no_grad():
    output = net(img)[0]  # drop the batch dimension -> (num_classes,)
    predict = torch.softmax(output, dim=0)
    predict_cla = torch.argmax(predict).item()
print(str(predict_cla), predict[predict_cla].item())
plt.show()

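结果

注意:原训练脚本中打印 train_acc 和 val_acc 的两条 print 语句参数写反了,所以下面日志里这两个标签是颠倒的。实际训练集准确率约为 0.70~0.71,验证集准确率约为 0.65~0.67。另外,评估时没有切换到 eval 模式,Dropout 仍在生效,因此这些数值偏低且有波动。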

epoch:8[8/10]	train_acc:0.66621067
epoch:8[8/10]	val_acc:0.69785641
此时的epoch为:9/10,训练的loss: 0.746190
此时的epoch为:9/10,训练的loss: 0.830095
此时的epoch为:9/10,训练的loss: 0.954925
epoch:9[9/10]	train_acc:0.64979480
epoch:9[9/10]	val_acc:0.70670296

(此处原文插入了一张图片)

发布了58 篇原创文章 · 获赞 12 · 访问量 1万+
展开阅读全文

没有更多推荐了,返回首页

©️2019 CSDN 皮肤主题: 书香水墨 设计师: CSDN官方博客

分享到微信朋友圈

×

扫一扫,手机浏览