Training, testing, and saving a PyTorch model on your own images

Contents

1. Training and test code

2. Problems encountered

3. Results

References


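1. Training and test code

To use the code, change the training and test data paths and set the number of output classes to match your data. This example has 6 classes.

Dataset layout:

The train and test directories each contain 6 subfolders: 0 1 2 3 4 5. Each folder name is the class name.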
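Before training, the layout described above can be sanity-checked with a few lines of standard-library Python. This is only a sketch: the function name `check_layout` and the example path are hypothetical, and it assumes the class folders are named 0 through 5 as in this post.

```python
import os

def check_layout(data_root, splits=("train", "test"), num_classes=6):
    """Return the class folders that are missing under data_root."""
    missing = []
    for split in splits:
        for i in range(num_classes):
            folder = os.path.join(data_root, split, str(i))
            if not os.path.isdir(folder):
                # Report as "split/class", independent of the OS separator
                missing.append(f"{split}/{i}")
    return missing

if __name__ == "__main__":
    # Hypothetical path; replace with your own dataset root
    print(check_layout("/path/to/dataset"))
```

An empty list means every expected class folder exists in both train and test.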

 

The code:

import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
plt.switch_backend('agg')

def loadtraindata():
    #path = r"/mnt/nas/cv_data/imagequality/waterloo_de20_all/train"
    path = r"/mnt/nas/cv_data/imagequality/testiq/train"
    trainset = torchvision.datasets.ImageFolder(path,
                                                transform=transforms.Compose(
                                                    [transforms.Resize((32, 32)),
                                                     transforms.CenterCrop(32),
                                                     transforms.ToTensor()])
                                                )
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)

    return trainloader

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 6)

    def forward(self, x):

        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)

        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

classes = ('0','1','2','3','4','5')

def loadtestdata():
    #path = r"/mnt/nas/cv_data/imagequality/waterloo_de20_all/test"
    path = r"/mnt/nas/cv_data/imagequality/testiq/test"
    testset = torchvision.datasets.ImageFolder(path,
                                               transform=transforms.Compose([
                                                   transforms.Resize((32, 32)),
                                                   transforms.ToTensor()])
                                               )
    testloader = torch.utils.data.DataLoader(testset, batch_size=25,shuffle=True, num_workers=2)
    return testloader

def trainandsave():
    trainloader = loadtraindata()
    net = Net()
    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    criterion = nn.CrossEntropyLoss()
    # train
    for epoch in range(5):
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            # get the inputs
            inputs, labels = data

            # Since PyTorch 0.4, tensors no longer need to be wrapped in Variable
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            #running_loss += loss.data[0]
            running_loss += loss.item()

            if i % 200 == 199:
                print('[%d, %5d] loss: %.3f' %
                      (epoch + 1, i + 1, running_loss / 200))
                running_loss = 0.0

    print('Finished Training')
    torch.save(net, 'net.pkl')
    torch.save(net.state_dict(), 'net_params.pkl')

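def reload_net():
    # The Net class must be defined/importable when unpickling a whole model.
    # On PyTorch 2.6+, torch.load defaults to weights_only=True, so loading a
    # fully pickled model needs torch.load('net.pkl', weights_only=False).
    trainednet = torch.load('net.pkl')
    return trainednet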

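def imshow(img):
    # No Normalize transform is applied when loading, so nothing to undo here
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    # The Agg backend cannot open a window; save to a file instead (name is arbitrary)
    plt.savefig('test_batch.png')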

def test():
    testloader = loadtestdata()
    net = reload_net()
    dataiter = iter(testloader)
    images, labels = next(dataiter)  # iterator.next() was removed; use the built-in next()
    imshow(torchvision.utils.make_grid(images, nrow=5))
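    # Iterate over the actual batch size so the index can never go out of range
    print('GroundTruth: ', " ".join('%5s' % classes[labels[j]] for j in range(len(labels))))
    net.eval()
    with torch.no_grad():  # no gradients needed for inference
        outputs = net(images)
    _, predicted = torch.max(outputs, 1)

    print('Predicted: ', " ".join('%5s' % classes[predicted[j]] for j in range(len(predicted))))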

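if __name__ == '__main__':
    # Guard the entry point: DataLoader workers (num_workers=2) may re-import this module
    trainandsave()
    test()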
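The `16 * 5 * 5` used for `fc1` and in `x.view(...)` follows from the 32x32 input size and the layer parameters. The arithmetic can be sketched as below; the helper `conv_out` is a hypothetical name, and it applies the standard output-size formula for convolution and pooling layers without dilation.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output width/height of a conv or pooling layer (no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1

size = 32                          # images are resized to 32x32
size = conv_out(size, 5)           # conv1, 5x5 kernel  -> 28
size = conv_out(size, 2, 2)        # 2x2 max pool       -> 14
size = conv_out(size, 5)           # conv2, 5x5 kernel  -> 10
size = conv_out(size, 2, 2)        # 2x2 max pool       -> 5
flat_features = 16 * size * size   # 16 output channels -> 400

print(size, flat_features)
```

If the input size in `Resize` changes, this number changes too, and `fc1` and the `view` call need to be updated to match.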

 

2. Problems encountered

(1) Problem 1

raise NotImplementedError

Fix:

The whitespace/indentation in the .py file was wrong.
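The post does not say which line was mis-indented. One common way to get this error (an assumption, not necessarily what happened here) is a `forward` that is indented too deep, so it becomes a local function inside `__init__` and the class falls back to the base `nn.Module.forward`, which raises `NotImplementedError`. `BadNet` and `GoodNet` below are hypothetical names used for illustration.

```python
import torch
import torch.nn as nn

class BadNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

        # Wrong: indented one level too deep, so this is a local function
        # inside __init__, not a method of the class.
        def forward(self, x):
            return self.fc(x)

class GoodNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    # Right: same indentation level as __init__
    def forward(self, x):
        return self.fc(x)

x = torch.zeros(1, 4)
try:
    BadNet()(x)
    bad_raises = False
except NotImplementedError:
    bad_raises = True

good_shape = tuple(GoodNet()(x).shape)
print(bad_raises, good_shape)
```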

 

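(2) Problem 2

AttributeError: 'NoneType' object has no attribute 'log_softmax'

Fix:

On checking the code, forward had no return statement, so the network returned None and the loss function failed on it.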

 

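(3) Problem 3
IndexError: invalid index of a 0-dim tensor. Use tensor.item() to convert a 0-dim tensor to a Python number

Fix: this is most likely due to a PyTorch version upgrade; newer versions return the loss as a 0-dim tensor, which cannot be indexed.

Change loss.data[0] to loss.item()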
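A small sketch of the difference, using a dummy batch rather than the real data (the all-zero logits and the label values are made up for illustration):

```python
import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()
logits = torch.zeros(4, 6)            # dummy batch: 4 samples, 6 classes
labels = torch.tensor([0, 1, 2, 3])
loss = criterion(logits, labels)      # a 0-dim tensor

try:
    loss.data[0]                      # old style: indexing a 0-dim tensor
    old_style_ok = True
except IndexError:
    old_style_ok = False

value = loss.item()                   # plain Python float
print(loss.dim(), old_style_ok, type(value))
```

`loss.item()` returns an ordinary Python float, so it is also the right thing to accumulate in `running_loss`.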


(4) Problem 4

raise RuntimeError('Invalid DISPLAY variable')

RuntimeError: Invalid DISPLAY variable

Fix:

matplotlib's default backend was TkAgg. The FltkAgg, GTK, GTKAgg, GTKCairo, TkAgg, Wx, and WxAgg backends all require a GUI, and the Linux environment I ran in has no graphical interface, hence the error.

Switch to a backend that does not need a GUI (Agg, Cairo, PS, PDF, or SVG):

import matplotlib.pyplot as plt
plt.switch_backend('agg')

Reference: https://www.cnblogs.com/bymo/p/7447409.html

 

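(5) Problem 5

print('GroundTruth: ', " ".join('%5s' % classes[labels[j]] for j in range(25)))

IndexError: index 12 is out of bounds for dimension 0 with size 12

Fix:

The index used when printing the test results went out of range. The argument to range() must not exceed the number of test images in the batch (12 here).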
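One way to avoid hard-coding the count is to clamp it to the batch size. The sketch below uses a plain list as a stand-in for the 12-element `labels` tensor; `format_labels` and `max_show` are hypothetical names.

```python
classes = ('0', '1', '2', '3', '4', '5')
labels = [3, 0, 5, 1, 2, 4, 0, 1, 2, 3, 4, 5]   # stand-in for a 12-image batch

def format_labels(labels, classes, max_show=25):
    # Never index past the end of the batch, whatever max_show is
    n = min(max_show, len(labels))
    return " ".join('%5s' % classes[labels[j]] for j in range(n))

line = format_labels(labels, classes)
print('GroundTruth: ', line)
```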

 

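3. Results

Tested on 1 image; the prediction was correct.

 

Training generates two model files: net.pkl (the whole model) and net_params.pkl (parameters only).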
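The training code saves both net.pkl and net_params.pkl, but the test code only reloads net.pkl. Reloading from the parameter file instead could look like the sketch below. `TinyNet` is a hypothetical stand-in for `Net` so the example stays short, and the file is written to a temporary directory rather than the working directory.

```python
import os
import tempfile

import torch
import torch.nn as nn

class TinyNet(nn.Module):  # hypothetical stand-in for Net
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 6)

    def forward(self, x):
        return self.fc(x)

net = TinyNet()
path = os.path.join(tempfile.mkdtemp(), 'net_params.pkl')
torch.save(net.state_dict(), path)          # parameters only

# To reload, build the same architecture first, then load the weights into it
restored = TinyNet()
restored.load_state_dict(torch.load(path))
restored.eval()

x = torch.ones(1, 8)
same = torch.equal(net(x), restored(x))
print(same)
```

Saving only the state_dict keeps the file independent of how the class was pickled, at the cost of having to construct the model in code before loading.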

 

References

Code: https://blog.csdn.net/a738833592/article/details/80900250

Error: https://blog.csdn.net/terry_zeng/article/details/25985419

Error: https://www.cnblogs.com/bymo/p/7447409.html

Error: https://blog.csdn.net/LYKXHTP/article/details/81565453

