2024年Python最全【图像分类】实战——使用VGG16实现对植物幼苗的分类(pytroch

transform_test = transforms.Compose([

transforms.Resize((224, 224)),

transforms.ToTensor(),

transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])

])

读取数据

====

将数据集解压后放到data文件夹下面,如图:

然后我们在dataset文件夹下面新建 __init__.py和dataset.py,在dataset.py文件夹写入下面的代码:

说一下代码的核心逻辑。

第一步 建立字典,定义类别对应的ID,用数字代替类别。

第二步 在__init__里面编写获取图片路径的方法。测试集只有一层路径直接读取,训练集在train文件夹下面是类别文件夹,先获取到类别,再获取到具体的图片路径。然后使用sklearn中切分数据集的方法,按照7:3的比例切分训练集和验证集。

第三步 在__getitem__方法中定义读取单个图片和类别的方法,由于图像中有位深度32位的,所以我在读取图像的时候做了转换。

coding:utf8

import os

from PIL import Image

from torch.utils import data

from torchvision import transforms as T

from sklearn.model_selection import train_test_split

Labels = {‘Black-grass’: 0, ‘Charlock’: 1, ‘Cleavers’: 2, ‘Common Chickweed’: 3,

‘Common wheat’: 4, ‘Fat Hen’: 5, ‘Loose Silky-bent’: 6, ‘Maize’: 7, ‘Scentless Mayweed’: 8,

‘Shepherds Purse’: 9, ‘Small-flowered Cranesbill’: 10, ‘Sugar beet’: 11}

class SeedlingData (data.Dataset):

def init(self, root, transforms=None, train=True, test=False):

“”"

主要目标: 获取所有图片的地址,并根据训练,验证,测试划分数据

“”"

self.test = test

self.transforms = transforms

if self.test:

imgs = [os.path.join(root, img) for img in os.listdir(root)]

self.imgs = imgs

else:

imgs_labels = [os.path.join(root, img) for img in os.listdir(root)]

imgs = []

for imglable in imgs_labels:

for imgname in os.listdir(imglable):

imgpath = os.path.join(imglable, imgname)

imgs.append(imgpath)

trainval_files, val_files = train_test_split(imgs, test_size=0.3, random_state=42)

if train:

self.imgs = trainval_files

else:

self.imgs = val_files

def getitem(self, index):

“”"

一次返回一张图片的数据

“”"

img_path = self.imgs[index]

img_path=img_path.replace(“\”,‘/’)

if self.test:

label = -1

else:

labelname = img_path.split(‘/’)[-2]

label = Labels[labelname]

data = Image.open(img_path).convert(‘RGB’)

data = self.transforms(data)

return data, label

def len(self):

return len(self.imgs)

然后我们在train.py调用SeedlingData读取数据 ,记着导入刚才写的dataset.py(from dataset.dataset import SeedlingData)

dataset_train = SeedlingData(‘data/train’, transforms=transform, train=True)

dataset_test = SeedlingData(“data/train”, transforms=transform_test, train=False)

读取数据

print(dataset_train.imgs)

导入数据

train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=False)

设置模型

====

使用CrossEntropyLoss作为loss,模型采用alexnet,选用预训练模型。更改全连接层,将最后一层类别设置为12,然后将模型放到DEVICE。优化器选用Adam。

实例化模型并且移动到GPU

criterion = nn.CrossEntropyLoss()

model_ft = vgg16(pretrained=True)

model_ft.classifier = classifier = nn.Sequential(

nn.Linear(512 * 7 * 7, 4096),

nn.ReLU(True),

nn.Dropout(),

nn.Linear(4096, 4096),

nn.ReLU(True),

nn.Dropout(),

nn.Linear(4096, 12),

)

model_ft.to(DEVICE)

选择简单暴力的Adam优化器,学习率调低

optimizer = optim.Adam(model_ft.parameters(), lr=modellr)

def adjust_learning_rate(optimizer, epoch):

“”“Sets the learning rate to the initial LR decayed by 10 every 30 epochs”“”

modellrnew = modellr * (0.1 ** (epoch // 50))

print(“lr:”, modellrnew)

for param_group in optimizer.param_groups:

param_group[‘lr’] = modellrnew

设置训练和验证

=======

定义训练过程

def train(model, device, train_loader, optimizer, epoch):

model.train()

sum_loss = 0

total_num = len(train_loader.dataset)

print(total_num, len(train_loader))

for batch_idx, (data, target) in enumerate(train_loader):

data, target = Variable(data).to(device), Variable(target).to(device)

output = model(data)

loss = criterion(output, target)

optimizer.zero_grad()

loss.backward()

optimizer.step()

print_loss = loss.data.item()

sum_loss += print_loss

if (batch_idx + 1) % 10 == 0:

print(‘Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}’.format(

epoch, (batch_idx + 1) * len(data), len(train_loader.dataset),

    • (batch_idx + 1) / len(train_loader), loss.item()))

ave_loss = sum_loss / len(train_loader)

print(‘epoch:{},loss:{}’.format(epoch, ave_loss))

验证过程

def val(model, device, test_loader):

model.eval()

test_loss = 0

correct = 0

total_num = len(test_loader.dataset)

print(total_num, len(test_loader))

with torch.no_grad():

for data, target in test_loader:

data, target = Variable(data).to(device), Variable(target).to(device)

output = model(data)

loss = criterion(output, target)

_, pred = torch.max(output.data, 1)

correct += torch.sum(pred == target)

print_loss = loss.data.item()

test_loss += print_loss

correct = correct.data.item()

acc = correct / total_num

avgloss = test_loss / len(test_loader)

print(‘\nVal set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n’.format(

avgloss, correct, len(test_loader.dataset), 100 * acc))

训练

for epoch in range(1, EPOCHS + 1):

adjust_learning_rate(optimizer, epoch)

train(model_ft, DEVICE, train_loader, optimizer, epoch)

val(model_ft, DEVICE, test_loader)

torch.save(model_ft, ‘model.pth’)

测试

我介绍两种常用的测试方式,第一种是通用的,通过自己手动加载数据集然后做预测,具体操作如下:

测试集存放的目录如下图:

第一步 定义类别,这个类别的顺序和训练时的类别顺序对应,一定不要改变顺序!!!!

第二步 定义transforms,transforms和验证集的transforms一样即可,别做数据增强。

第三步 加载model,并将模型放在DEVICE里,

第四步 读取图片并预测图片的类别,在这里注意,读取图片用PIL库的Image。不要用cv2,transforms不支持。

import torch.utils.data.distributed

import torchvision.transforms as transforms

from PIL import Image

from torch.autograd import Variable

import os

classes = (‘Black-grass’, ‘Charlock’, ‘Cleavers’, ‘Common Chickweed’,

‘Common wheat’,‘Fat Hen’, ‘Loose Silky-bent’,

‘Maize’,‘Scentless Mayweed’,‘Shepherds Purse’,‘Small-flowered Cranesbill’,‘Sugar beet’)

transform_test = transforms.Compose([

transforms.Resize((224, 224)),

transforms.ToTensor(),

transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])

])

DEVICE = torch.device(“cuda:0” if torch.cuda.is_available() else “cpu”)

model = torch.load(“model.pth”)

model.eval()

model.to(DEVICE)

path=‘data/test/’

testList=os.listdir(path)

for file in testList:

img=Image.open(path+file)

img=transform_test(img)

img.unsqueeze_(0)

img = Variable(img).to(DEVICE)

out=model(img)

Predict

_, pred = torch.max(out.data, 1)

print(‘Image Name:{},predict:{}’.format(file,classes[pred.data.item()]))

第二种 使用自定义的Dataset读取图片

import torch.utils.data.distributed

import torchvision.transforms as transforms

from dataset.dataset import SeedlingData

from torch.autograd import Variable

classes = (‘Black-grass’, ‘Charlock’, ‘Cleavers’, ‘Common Chickweed’,

‘Common wheat’,‘Fat Hen’, ‘Loose Silky-bent’,

‘Maize’,‘Scentless Mayweed’,‘Shepherds Purse’,‘Small-flowered Cranesbill’,‘Sugar beet’)

transform_test = transforms.Compose([

transforms.Resize((224, 224)),

transforms.ToTensor(),

transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])

])

DEVICE = torch.device(“cuda:0” if torch.cuda.is_available() else “cpu”)

model = torch.load(“model.pth”)

model.eval()

model.to(DEVICE)

dataset_test =SeedlingData(‘data/test/’, transform_test,test=True)

print(len(dataset_test))

对应文件夹的label

for index in range(len(dataset_test)):

item = dataset_test[index]

img, label = item

img.unsqueeze_(0)

data = Variable(img).to(DEVICE)

output = model(data)

_, pred = torch.max(output.data, 1)

print(‘Image Name:{},predict:{}’.format(dataset_test.imgs[index], classes[pred.data.item()]))

index += 1

完整代码

====

train.py

import torch.optim as optim

import torch

import torch.nn as nn

import torch.nn.parallel

import torch.utils.data

import torch.utils.data.distributed

import torchvision.transforms as transforms

from dataset.dataset import SeedlingData

from torch.autograd import Variable

from torchvision.models import vgg16

设置全局参数

modellr = 1e-4

BATCH_SIZE = 32

EPOCHS = 10

DEVICE = torch.device(‘cuda’ if torch.cuda.is_available() else ‘cpu’)

文末有福利领取哦~

👉一、Python所有方向的学习路线

Python所有方向的技术点做的整理,形成各个领域的知识点汇总,它的用处就在于,你可以按照上面的知识点去找对应的学习资源,保证自己学得较为全面。img

👉二、Python必备开发工具

img
👉三、Python视频合集

观看零基础学习视频,看视频学习是最快捷也是最有效果的方式,跟着视频中老师的思路,从基础到深入,还是很容易入门的。
img

👉 四、实战案例

光学理论是没用的,要学会跟着一起敲,要动手实操,才能将自己的所学运用到实际当中去,这时候可以搞点实战案例来学习。(文末领读者福利)
img

👉五、Python练习题

检查学习结果。
img

👉六、面试资料

我们学习Python必然是为了找到高薪的工作,下面这些面试题是来自阿里、腾讯、字节等一线互联网大厂最新的面试资料,并且有阿里大佬给出了权威的解答,刷完这一套面试资料相信大家都能找到满意的工作。
img

img

👉因篇幅有限,仅展示部分资料,这份完整版的Python全套学习资料已经上传

网上学习资料一大堆,但如果学到的知识不成体系,遇到问题时只是浅尝辄止,不再深入研究,那么很难做到真正的技术提升。

需要这份系统化学习资料的朋友,可以戳这里无偿获取

一个人可以走的很快,但一群人才能走的更远!不论你是正从事IT行业的老鸟或是对IT行业感兴趣的新人,都欢迎加入我们的的圈子(技术交流、学习资源、职场吐槽、大厂内推、面试辅导),让我们一起学习成长!

  • 13
    点赞
  • 21
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值