【个人学习】使用torch实现xml文件返回的图片

个人学习使用,若有错误,敬请指正

主要实现

构建数据集类,使用xml的文件信息去提取车牌特征,转换尺寸之后输入神经网络。

因为没有搜索到纯代码实现xml的图片内容,所以自己尝试使用torch框架并实现。

实现代码

导入包和定义文件路径

import os 
import matplotlib.pyplot as plt 
from PIL import Image
import random
import xml.etree.ElementTree as ET
import torch
import torch.optim as optim 
from torchvision import transforms, datasets,utils
from torch.utils.data import DataLoader, Dataset
from skimage import io, transform
import glob 
import torchvision 
import numpy as np 
plt.rcParams["font.sans-serif"]=["SimHei"] #设置字体
plt.rcParams["axes.unicode_minus"]=False #该语句解决图像中的缺失中文

# 定义训练集和测试集的路径
trainData_dir = 'Train/images'
testData_dir = 'Test/images'
# 引入Annotations文件来增强特征提取
trainAnno_dir = 'Train/annotations'
testAnno_dir = 'Train/annotations'

images= os.listdir(trainData_dir)
# print(f'文件夹下总共有{len(images)}条数据')

# 构建图片名称和索引的映射
name_map = {}
for i in range(len(images)):
    # print(images[i][:-4])
    name_map[i] =  images[i][:-4]

print(len(name_map))

自定义数据集的类

'''https://pytorch.org/tutorials/beginner/data_loading_tutorial.html
'''
class XMLDataset(Dataset):
    def __init__(self,trainData_dir,trainAnno_dir,name_map,transform):
        # 图像总路径
        self.img_path = trainData_dir
        # 注解总路径
        self.ann_path = trainAnno_dir
        # 总名称路径
        self.name_map = name_map
        self.transform = transform


        
    def __len__(self):
        '''返回数据集的总长度'''
        return len(self.name_map)

    def __getitem__(self, idx):
        '''以支持索引,使得数据集[i]可以用于获得第i个样本。'''
        # 输入索引,转换为名称
        idx = self.name_map[idx]
        img_name = idx+'.png'
        ann_name = idx+'.xml'

        # 图片完整路径
        img_path = os.path.join(self.img_path,img_name)
        # 注解完整路径
        ann_path = os.path.join(self.ann_path,ann_name)

        # 对图片进行xml特征提取
        img = Image.open(img_path).convert('RGB')


        # 加载图片,从xml文件中获得bbox的数组然后返回图片矩阵
        tree = ET.parse(ann_path)
        root = tree.getroot()
        objects = root.findall('object')
        # 从xml中查看属性名称
        for obj in objects:
            bounding_box = obj.find('bndbox')
            # 方法2
            xmin = int(bounding_box[0].text) 
            ymin = int(bounding_box[1].text)
            xmax = int(bounding_box[2].text)
            ymax = int(bounding_box[3].text)  
            # label = obj.find('name').text 
            # label = obj.find('truncated').text

        bbox = (xmin, ymin, xmax, ymax)
        # 转换尺寸
        img = img.crop(bbox)
        
        img = img.resize((64,64))

        img = np.array(img,dtype=np.float32) # unsinged char 和float、
        '''
        swap color axis because
        # numpy image: H x W x C
        # torch image: C x H x W
        '''
        img = img.transpose((2, 0, 1))# 转换格式 
        
        return img,torch.tensor(0)
    

使用数据集加载器和定义模型

dataloader  = DataLoader(license_dataset,batch_size=4,shuffle=True,num_workers=0)

# 测试在批量中的图片信息
# for i_batch,sample_batched in enumerate(dataloader):
#     # print(i_batch, sample_batched['image'].size(),
#     #       sample_batched['label'].size())
    
#     # observe 4th batch and stop.
#     if i_batch == 3:
#         plt.figure()
        
#         show_crop_batch(sample_batched)
#         plt.axis('off')
#         plt.ioff()
#         plt.show()
#         break



''' 定义模型'''
import torch.nn as nn 
import torch.nn.functional as F 


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        '''
        in_channels是输入的四维张量[N, C, H, W]中的C了,即输入张量的channels数。这个形参是确定权重等可学习参数的shape所必需的。
        
        torch的张量输入是[Num,channels,Height,Weight]

        '''

        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)


        # self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc1 = nn.Linear(2704, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 1)

    def forward(self, x):


        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1) # flatten all dimensions except batch
        
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)

        return x

初始化模型并训练

net = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(),lr=0.001,momentum=0.9)
# 训练网络
N_EPOCHS=5 
for epoch in range(N_EPOCHS):
    running_loss = 0.0 
    # for i,data in enumerate([image,'licence'],0):
    for i,data in enumerate(dataloader,0):
        # print(type(data))# 我的返回是字典


        # 获得输入,data的形状是[inputs,labels]
        inputs,labels = data[0],data[1]
        
        # print(f'inputs_shape',inputs.shape)
        # print(f"inputs={inputs},labels={labels}")
        # print('inputs_shape,length of inputs',inputs.shape,len(inputs))
        # print(type(inputs),inputs.dtype) # 后者查看具体类型
        # inputs = inputs.astype('float')
        
        # 零梯度参数
        optimizer.zero_grad()

        # 前向传播+后向传播+优化
        outputs = net(inputs)
        # print(outputs)

        loss = criterion(outputs,labels)
        loss.backward()
        optimizer.step()

        # 打印静态信息
        running_loss += loss.item()
        if i % 100 == 1:    # print every 2000 mini-batches
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 100:.3f}')
            running_loss = 0.0

print('Finished Training')

实现截图

在这里插入图片描述

参考博客

1.https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
2.https://pytorch.org/tutorials/beginner/data_loading_tutorial.html
3.https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html
4.https://www.kaggle.com/code/bitthal/understanding-input-data-and-loading-with-pytorch

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值