Pytorch训练自定义数据集

有点乱,先放在这里,主要做个记录。

Pytorch训练自定义数据集

特征提取

参考了两种网络结构(一开始用的是VGG,以为用的网络有问题,又加了一个)

VGG13

VGG13网络,对最后的全连接层改小了一点。

import torch.nn as nn

class MyVGG13(nn.Module):
    """VGG-13-style network with a slimmed-down classifier head.

    The convolutional stack follows the standard VGG-13 layout: five
    stages of (two 3x3 convs + 2x2 max-pool), halving the spatial size
    per stage, so a 224x224 input ends as a 512x7x7 feature map. The
    fully connected layers are reduced to 1024 units to cut parameters.
    """

    def __init__(self, numclasses=200):
        super(MyVGG13, self).__init__()

        # Output channels per stage; each stage is conv-relu-conv-relu-pool.
        stage_channels = (64, 128, 256, 512, 512)
        layers = []
        in_ch = 3
        for out_ch in stage_channels:
            layers += [
                nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=2, stride=2),
            ]
            in_ch = out_ch
        self.features = nn.Sequential(*layers)

        # NOTE: attribute name kept as "classifer" (sic) so previously
        # saved state_dicts (e.g. best.pth) remain loadable.
        self.classifer = nn.Sequential(
            nn.Linear(512 * 7 * 7, 1024),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(1024, 1024),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(1024, numclasses),
        )

    def forward(self, x):
        """Return raw class logits of shape (N, numclasses)."""
        feats = self.features(x)
        flat = feats.view(feats.size(0), -1)
        return self.classifer(flat)
AlexNet

AlexNet:这里把输入尺寸改为 224(原版 AlexNet 是 227),其余结构基本相同。

import torch.nn as nn
import torch
from torchsummary import summary

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
class MyAlexNet(nn.Module):
    """AlexNet-like classifier adapted to 224x224 inputs.

    Differs from the classic AlexNet in that the second conv uses
    stride 2 / padding 1, which shrinks the final feature map to
    256x2x2 (224 -> 55 -> 27 -> 13 -> 6 -> 2 spatially), and the fully
    connected layers are halved to 2048 units to match.
    """

    def __init__(self, numclasses=200):
        super(MyAlexNet, self).__init__()

        # Convolutional backbone.
        self.features = nn.Sequential(
            nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(96, 256, kernel_size=5, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 384, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        # Head: flatten the 256x2x2 map into class logits.
        # NOTE: attribute name kept as "classifer" (sic) so existing
        # checkpoints keep loading.
        self.classifer = nn.Sequential(
            nn.Linear(256 * 2 * 2, 2048),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(2048, 2048),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(2048, numclasses),
        )

    def forward(self, x):
        """Return raw logits of shape (N, numclasses) for input (N, 3, 224, 224)."""
        feats = self.features(x)
        return self.classifer(feats.view(feats.size(0), -1))


整体

自己实现了一个网络,对一些知识有一定的理解,还是很好的。
关于 learning_rate,算是有了一点点理解:真就是太大也不行,太小也不行。(我一开始设置的学习率不合适,导致不管训练多少个 epoch,准确率一直停在同一个定值,让我误以为是数据集处理有问题,反复检查了好几次都没找到原因。)

参考

调参
代码整体结构比较好,要学着改自己的网络。

import argparse
import os
import glob
import csv
import PIL
import visdom
import matplotlib.pyplot as plt
import torch
import time

import torch.nn as nn

from train import MyAlexNet
from myVGG13 import MyVGG13
from torch.utils.data import Dataset
from torchvision import transforms
from torchsummary import summary

torch.manual_seed(12345)  # fix the RNG so weight init / shuffling are reproducible

# Run on the first GPU when available, otherwise fall back to CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

class MyData(Dataset):
    """Image-folder dataset: ``root/<class_name>/*.jpg``.

    Integer labels are assigned to class directories in sorted-name
    order, and the (path, label) pairs are cached in a CSV file inside
    ``root`` so the directory scan only happens once.
    """

    def __init__(self, root, transform=None):
        """
        Args:
            root: dataset root directory; each subdirectory is one class.
            transform: optional callable applied to each PIL image.
        """
        super(MyData, self).__init__()
        self.root = root
        self.transform = transform

        # Map class-directory name -> integer label, in sorted order so
        # the assignment is stable across runs and platforms.
        # e.g. {'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflowers': 3, 'tulips': 4}
        self.name2label = {}
        for name in sorted(os.listdir(os.path.join(root))):
            if not os.path.isdir(os.path.join(root, name)):
                continue
            self.name2label[name] = len(self.name2label.keys())
        print(self.name2label)

        # Parallel lists: image file paths and their integer labels.
        self.images, self.labels = self.load_csv('Image2Label.csv')

    def load_csv(self, fliename):
        """Return ``(images, labels)``, building the CSV cache on first use.

        NOTE: if class folders change after the CSV exists, the stale
        cache is reused silently — delete the CSV to force a rescan.
        (Parameter name kept as-is for backward compatibility.)
        """
        if not os.path.exists(os.path.join(self.root, fliename)):
            images = []
            for name in self.name2label.keys():
                images += glob.glob(os.path.join(self.root, name, '*.jpg'))
            # BUG FIX: glob order is OS-dependent, which made the cached
            # CSV (and therefore sample order) nondeterministic. Sort so
            # every run produces the same cache.
            images.sort()

            print(len(images), images)
            with open(os.path.join(self.root, fliename), mode='w', newline='') as f:
                writer = csv.writer(f)
                for img in images:
                    # The class name is the parent directory of the file,
                    # e.g. ...\val\daisy\00000.jpg -> 'daisy'.
                    name = img.split(os.sep)[-2]
                    label = self.name2label[name]
                    writer.writerow([img, label])
                print('writen into csv file:', fliename)

        # Read the cache back: each row is (path, label-as-string).
        images, labels = [], []
        with open(os.path.join(self.root, fliename)) as f:
            reader = csv.reader(f)
            for row in reader:
                img, label = row
                images.append(img)
                labels.append(int(label))

        assert len(images) == len(labels)
        return images, labels

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``; image is PIL or transformed."""
        input_image, input_label = self.images[idx], self.labels[idx]
        # Path -> image data; force RGB so grayscale/palette files don't
        # break the 3-channel transforms downstream.
        input_image = PIL.Image.open(input_image).convert('RGB')
        if self.transform:
            input_image = self.transform(input_image)
        return input_image, input_label

def evalute(net, Dataloader):
    """Return top-1 classification accuracy of ``net`` over ``Dataloader``.

    (Function name kept as-is — callers already use ``evalute``.)

    Switches the network to eval mode so dropout is disabled while
    scoring; the training loop re-enables train mode each epoch.
    """
    total = len(Dataloader.dataset)
    if total == 0:
        return 0.0  # avoid division by zero on an empty dataset

    # Evaluate on whatever device the model's parameters live on,
    # instead of depending on a module-level `device` global.
    dev = next(net.parameters()).device

    # BUG FIX: the original never called net.eval(), so dropout stayed
    # active and the measured accuracy was noisy / pessimistic.
    net.eval()
    correct = 0
    with torch.no_grad():
        for x, y in Dataloader:
            x, y = x.to(dev), y.to(dev)
            output = net(x)
            pred = output.argmax(dim=1)
            correct += torch.eq(pred, y).sum().float().item()

    return correct / total




def main():
    """Train a flower classifier and checkpoint the best validation epoch.

    Command-line flags:
        --batchsize  mini-batch size (default 32)
        --epochs     number of training epochs (default 10)

    The weights with the highest validation accuracy are saved to 'best.pth'.
    """
    parser = argparse.ArgumentParser(description='训练参数')
    parser.add_argument('--batchsize', type=int, default=32, help='The number of batch_size')
    parser.add_argument('--epochs', type=int, default=10, help='The number of epochs')
    args = parser.parse_args()

    #viz=visdom.Visdom()  # optional visdom visualization window (disabled)

    train_image_path = r'D:\Projects\DeepLearning\Dataset\flower_photos\train'
    val_image_path = r'D:\Projects\DeepLearning\Dataset\flower_photos\val'

    # Resize to the 224x224 input the networks expect, then normalize
    # each channel to roughly [-1, 1].
    tf = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    train_data = MyData(train_image_path, tf)
    val_data = MyData(val_image_path, tf)

    train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batchsize, shuffle=True, num_workers=4)
    val_loader = torch.utils.data.DataLoader(val_data, batch_size=args.batchsize, shuffle=True, num_workers=4)

    net = MyVGG13(5).to(device)

    # Print a layer-by-layer summary; also sanity-checks that the
    # classifier's input size matches the feature-map size.
    summary(net, (3, 224, 224))

    optimizer = torch.optim.Adam(net.parameters(), lr=0.0004)
    criterion = nn.CrossEntropyLoss()

    best_acc, best_epoch = 0, 0

    for epoch in range(args.epochs):
        net.train()  # re-enable dropout after the eval pass below
        for index, (x, y) in enumerate(train_loader):
            x, y = x.to(device), y.to(device)
            outputs = net(x)
            loss = criterion(outputs, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if index % 32 == 0:
                print('此时的epoch为:{}/{},训练的loss: {:.6f}'.format(epoch, args.epochs, loss.item()))

        if epoch % 1 == 0:
            # Dropout must be off while measuring accuracy.
            net.eval()
            train_acc = evalute(net, train_loader)
            val_acc = evalute(net, val_loader)

            # BUG FIX: the original printed val_acc under the train_acc
            # label and vice versa; the values now match their labels.
            print('epoch:{}[{}/{}]\ttrain_acc:{:.8f}'.format(epoch, epoch, args.epochs, train_acc))
            print('epoch:{}[{}/{}]\tval_acc:{:.8f}'.format(epoch, epoch, args.epochs, val_acc))

            if val_acc > best_acc:
                best_epoch = epoch
                best_acc = val_acc
                # Keep only the best-so-far weights on disk.
                torch.save(net.state_dict(), 'best.pth')

    print('best_acc:', best_acc, 'best_epoch:', best_epoch)


if __name__ == '__main__':
    main()

预测

参考

from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import torch

from train import MyAlexNet
from myVGG13 import MyVGG13

# Single-image inference: restore the trained checkpoint, classify one
# image, and display it with matplotlib.

# Preprocessing must match what was used during training.
data_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

Image_path = r'./4.jpg'
raw_img = Image.open(Image_path)
plt.imshow(raw_img)  # queue the original image for display

# Preprocess and add a batch dimension: (3, 224, 224) -> (1, 3, 224, 224).
batch = torch.unsqueeze(data_transform(raw_img), dim=0)

# Rebuild the architecture, then load the best training weights.
net = MyVGG13(5)
net.load_state_dict(torch.load('best.pth'))
net.eval()  # disable dropout for inference

with torch.no_grad():
    logits = torch.squeeze(net(batch))    # drop the batch dimension
    probs = torch.softmax(logits, dim=0)  # logits -> probabilities
    top_class = torch.argmax(probs).numpy()

# Predicted class index and its probability.
print(str(top_class), probs[top_class].item())
plt.show()

结果

epoch:8[8/10]	train_acc:0.66621067
epoch:8[8/10]	val_acc:0.69785641
此时的epoch为:9/10,训练的loss: 0.746190
此时的epoch为:9/10,训练的loss: 0.830095
此时的epoch为:9/10,训练的loss: 0.954925
epoch:9[9/10]	train_acc:0.64979480
epoch:9[9/10]	val_acc:0.70670296

(此处原文插入了预测结果的截图)

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值