PyTorch —— 使用 PyTorch 进行图像定位（边界框回归）

一、代码中的数据集可以通过以下链接获取

百度网盘提取码：vc56（注：原文中的网盘链接未能保留下来）

二、代码运行环境

PyTorch（GPU 版）==1.10.1
Python==3.8

三、数据集处理代码如下所示（保存为 data_loader.py，训练与预测脚本通过 `from data_loader import load_data` 导入）

import os
import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset, DataLoader
import xml.etree.ElementTree as ET
import numpy as np
from torchvision import transforms
from PIL import Image
from torchvision.utils import draw_bounding_boxes


class PetDataset(Dataset):
    """Map-style dataset pairing image files with four bounding-box coordinates.

    Each item is ``(image_tensor, c1, c2, c3, c4)`` where ``image_tensor`` is the
    RGB image after ``transform`` and ``c1..c4`` are the four entries of
    ``labels[index]`` (presumably xmin, ymin, xmax, ymax as produced by
    ``to_labels`` -- confirm against the caller).
    """

    def __init__(self, images_path, labels, transform):
        super().__init__()
        # Parallel sequences: images_path[i] is described by labels[i].
        self.images_path = images_path
        self.labels = labels
        self.transform = transform

    def __getitem__(self, index):
        # Force 3 channels so grayscale / RGBA files produce the same tensor layout.
        rgb_image = Image.open(self.images_path[index]).convert('RGB')
        image_tensor = self.transform(rgb_image)
        first, second, third, fourth = self.labels[index]
        return image_tensor, first, second, third, fourth

    def __len__(self):
        return len(self.images_path)


def to_labels(path):
    """Read the first object's bounding box from a VOC-style XML annotation.

    The box is normalized by the image size stored in ``<size>``.

    Args:
        path: path to the XML annotation file.

    Returns:
        ``[xmin / width, ymin / height, xmax / width, ymax / height]`` as floats.
        Only the first ``<object>`` element is read.

    Raises:
        AttributeError: if ``<size>``, ``<object>``, ``<bndbox>`` or one of the
            coordinate tags is missing (``find`` returns ``None``).
    """
    # Passing the path lets ElementTree open the file in binary mode, honor the
    # XML encoding declaration, and close the handle itself. The previous version
    # opened the file and never closed it.
    root = ET.parse(path).getroot()

    size = root.find('size')
    width = float(size.find('width').text)
    height = float(size.find('height').text)

    bndbox = root.find('object').find('bndbox')
    xmin = float(bndbox.find('xmin').text) / width
    ymin = float(bndbox.find('ymin').text) / height
    xmax = float(bndbox.find('xmax').text) / width
    ymax = float(bndbox.find('ymax').text) / height

    return [xmin, ymin, xmax, ymax]


def load_data(dataset_path=r'/Users/leeakita/Desktop/dataset', batch_size=32, seed=2022, train_ratio=0.8):
    """Build train/test DataLoaders for the pet bounding-box dataset.

    Expects ``dataset_path`` to contain an ``xmls`` folder of annotation files
    and an ``images`` folder holding ``<name>.jpg`` for every ``<name>.xml``.

    Args:
        dataset_path: root folder of the dataset (defaults to the original author's path).
        batch_size: batch size used by both loaders.
        seed: seed of the shuffle that decides the train/test split.
        train_ratio: fraction of samples assigned to the training split.

    Returns:
        ``(train_dl, test_dl)``. The train loader reshuffles every epoch; the
        test loader keeps a fixed order. Images are resized to 224x224 tensors
        and labels are float32 normalized box coordinates.

    Raises:
        ValueError: if the ``xmls`` folder contains no ``.xml`` files.
    """
    xml_dir = os.path.join(dataset_path, 'xmls')
    image_dir = os.path.join(dataset_path, 'images')

    # Keep only real annotation files (skips e.g. macOS ".DS_Store", which the old
    # split('.')[0] turned into an empty stem). Sorting makes the seeded split
    # below independent of the filesystem's listdir order.
    xml_names = sorted(name for name in os.listdir(xml_dir) if name.lower().endswith('.xml'))
    if not xml_names:
        raise ValueError('no .xml annotation files found in {}'.format(xml_dir))

    # splitext keeps dots that are part of the stem ("a.b.xml" -> "a.b").
    image_paths = [os.path.join(image_dir, os.path.splitext(name)[0] + '.jpg') for name in xml_names]
    labels = [to_labels(os.path.join(xml_dir, name)) for name in xml_names]

    # A private RandomState yields the same permutation as np.random.seed(seed)
    # followed by np.random.permutation, without reseeding NumPy's global RNG.
    order = np.random.RandomState(seed).permutation(len(image_paths))
    image_paths = np.array(image_paths)[order]
    labels = np.array(labels, dtype=np.float32)[order]

    train_split = int(len(image_paths) * train_ratio)

    train_images, test_images = image_paths[:train_split], image_paths[train_split:]
    train_labels, test_labels = labels[:train_split], labels[train_split:]

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ])

    train_ds = PetDataset(images_path=train_images, labels=train_labels, transform=transform)
    test_ds = PetDataset(images_path=test_images, labels=test_labels, transform=transform)

    train_dl = DataLoader(dataset=train_ds, batch_size=batch_size, shuffle=True)
    test_dl = DataLoader(dataset=test_ds, batch_size=batch_size)

    return train_dl, test_dl


if __name__ == '__main__':
    # Visual sanity check: draw the ground-truth box of one test sample in red.
    trainn, testt = load_data()
    image, xminn, yminn, xmaxn, ymaxn = next(iter(testt))
    index = 6
    sample = image[index]
    # Labels are normalized to [0, 1]; scale back to the 224x224 resized image.
    pixel_coords = [coord[index].item() * 224 for coord in (xminn, yminn, xmaxn, ymaxn)]
    boxes = torch.FloatTensor(pixel_coords).unsqueeze(0)
    uint8_image = torch.as_tensor(data=sample * 255, dtype=torch.uint8)
    result = draw_bounding_boxes(image=uint8_image, boxes=boxes, colors='red')
    plt.imshow(result.permute(1, 2, 0).numpy())
    plt.show()

四、模型的构建代码如下所示（保存为 model_loader.py，训练与预测脚本通过 `from model_loader import Net` 导入）

import torch
import torchvision
from torch import nn


class Net(nn.Module):
    """Bounding-box regressor: a ResNet-101 feature extractor plus four linear heads.

    The backbone is ``torchvision.models.resnet101(pretrained=True)`` with its last
    child module (the ``fc`` classifier) removed. Each head maps the pooled features
    to a single scalar, so ``forward`` returns four tensors of shape ``(batch, 1)``,
    one per box coordinate.
    """

    def __init__(self):
        super(Net, self).__init__()
        resnet = torchvision.models.resnet101(pretrained=True)
        # All child modules except the final classifier.
        self.conv_base = nn.Sequential(*list(resnet.children())[:-1])
        in_features = resnet.fc.in_features
        self.fc1 = nn.Linear(in_features=in_features, out_features=1)
        self.fc2 = nn.Linear(in_features=in_features, out_features=1)
        self.fc3 = nn.Linear(in_features=in_features, out_features=1)
        self.fc4 = nn.Linear(in_features=in_features, out_features=1)

    def forward(self, x):
        """Return the four head outputs, each of shape ``(batch, 1)``."""
        x = self.conv_base(x)
        # flatten from dim 1 always keeps the batch dimension. The former
        # torch.squeeze(x) also removed it when batch == 1, turning (1, C, 1, 1)
        # into (C,) so every head returned shape (1,) instead of (1, 1).
        x = torch.flatten(x, 1)
        return self.fc1(x), self.fc2(x), self.fc3(x), self.fc4(x)


if __name__ == '__main__':
    # Smoke test: only checks that the network can be constructed
    # (this fetches the pretrained ResNet-101 weights through torchvision).
    net = Net()

五、模型的训练代码如下所示

import numpy as np
import torch
from data_loader import load_data
from model_loader import Net
import tqdm
import os

# Device configuration: use the GPU when one is available.
devices = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Data loading.
train_dl, test_dl = load_data()

# Model construction.
model = Net()
model.to(device=devices)

# Training configuration.
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.0001)
# Multiply the learning rate by 0.7 every 7 epochs. StepLR only advances when
# lr_scheduler.step() is called; the loop below calls it once per epoch (the
# original script never called it, so the learning rate never decayed).
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=7, gamma=0.7)


def _batch_loss(batch):
    """Return the summed MSE of the four predicted box coordinates for one batch.

    ``batch`` is ``(image, xmin, ymin, xmax, ymax)`` as yielded by the loaders
    above. Every tensor is moved to ``devices`` before the forward pass.
    """
    image, *targets = (tensor.to(devices) for tensor in batch)
    predictions = model(image)
    # reshape(-1) turns each head's output into shape (batch,) to match the
    # targets. A bare squeeze() would collapse a batch of one to a 0-d tensor.
    return sum(loss_fn(pred.reshape(-1), target) for pred, target in zip(predictions, targets))


# Training loop.
for epoch in range(50):
    model.train()
    train_tqdm = tqdm.tqdm(iterable=train_dl, total=len(train_dl))
    train_tqdm.set_description_str('Train epoch: {:2d}'.format(epoch))
    train_loss_sum = []
    for batch in train_tqdm:
        loss = _batch_loss(batch)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # .item() already returns a detached Python float; no torch.no_grad() needed.
        train_loss_sum.append(loss.item())
        train_tqdm.set_postfix_str('loss is :{:14f}'.format(np.mean(train_loss_sum)))
    train_tqdm.close()
    # Advance the learning-rate schedule once per epoch, after the optimizer steps.
    lr_scheduler.step()

    model.eval()
    with torch.no_grad():
        test_tqdm = tqdm.tqdm(iterable=test_dl, total=len(test_dl))
        test_tqdm.set_description_str('Test epoch: {:2d}'.format(epoch))
        test_loss_sum = []
        for batch in test_tqdm:
            test_loss_sum.append(_batch_loss(batch).item())
            test_tqdm.set_postfix_str('loss is :{:14f}'.format(np.mean(test_loss_sum)))
        test_tqdm.close()

# Save the trained weights.
os.makedirs('model_data', exist_ok=True)
torch.save(model.state_dict(), os.path.join('model_data', 'model.pth'))

六、模型的预测代码如下所示

import torch
from data_loader import load_data
from model_loader import Net
import os
from torchvision.utils import draw_bounding_boxes
import matplotlib.pyplot as plt

# Load one batch from the test split.
train_dl, test_dl = load_data()
image, xmin, ymin, xmax, ymax = next(iter(test_dl))

# Restore the trained weights (mapped to CPU so no GPU is required) and switch to eval mode.
model = Net()
model.load_state_dict(torch.load(os.path.join('model_data', 'model.pth'), map_location='cpu'))
model.eval()


def _pixel_box(coords, size=224):
    """Scale four normalized coordinate tensors to a (1, 4) pixel-space box tensor."""
    return torch.FloatTensor([coord.item() * size for coord in coords]).unsqueeze(0)


# Run the prediction and draw predicted (red) vs ground-truth (blue) boxes for one sample.
index = 0
with torch.no_grad():
    predictions = model(image)
    pre_boxes = _pixel_box(head[index] for head in predictions)
    label_boxes = _pixel_box(target[index] for target in (xmin, ymin, xmax, ymax))
    img = torch.as_tensor(data=image[index] * 255, dtype=torch.uint8)
    result = draw_bounding_boxes(image=img, boxes=pre_boxes, colors='red')
    result = draw_bounding_boxes(image=result, boxes=label_boxes, colors='blue')
    plt.figure(figsize=(8, 8), dpi=500)
    plt.axis('off')
    plt.imshow(result.permute(1, 2, 0))
    plt.savefig('result.png')
    plt.show()

七、代码的运行结果如下所示

（此处原为运行结果截图，即预测脚本保存的 result.png：红框为模型预测的边界框，蓝框为真实标注框。）

  • 1
    点赞
  • 15
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

水哥很水

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值