初学全连接神经网络项目笔记——飞机识别(客机、战斗机、直升机)

1.数据准备

1.爬虫代码爬取百度图库的图片

import re
import threading
import time
from threading import Thread, Lock
import requests
# 关键字
keyword = '战斗机'
# 存放的目录
img_dir = 'images/zhandou/'
# 每一页的图片数量
page_num = 30
# 爬取的地址
urls = [f'https://image.baidu.com/search/flip?tn=baiduimage&ie=utf-8&word="{keyword}"&pn={30 * index}'
        for index in range(page_num)]
# 统计图片数量
lock = Lock()
image_count = 0
class Spider(Thread):
    def __init__(self, name):
        super(Spider, self).__init__()
        self.name = name
    """
        下载数据
        1. headers
        2. 发起请求
    """
    def down_load(self, url):
        global image_count
        headers = {
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/112.0.0.0 Safari/537.36'
        }
        result = requests.get(url, timeout=10, headers=headers)
        """
        # print(result.text)
        # .除\n(换行符)之外的任意字符
        # *匹配0 - 无穷次
        # ?非贪婪模式(一旦拿到数据直接返回,不再向下匹配)
        
        "objURL":"https://imgx.xiawu.com/xzimg/i4/i2/TB1FDR2FVXXXXXMXpXXXXXXXXXX_%21%210-item_pic.jpg",
        'https://image.baidu.com/search/flip?tn=baiduimage&ie=utf-8&word="草莓"&pn=0'
        """
        # 需要匹配包括换行符在内的所有字符。这时可以使用re.S标志
        image_urls = re.findall('"objURL":"(.*?)",', result.text, re.S)
        print(image_urls)
        """
        请求图片
        保存图片
         图片命名 序号+格式
        """
        for image_url in image_urls:
            try:
                image_name = str(int(time.time() * 1000000)) + '.jpg'
                image_url = image_url.strip('"')
                image_url = image_url.strip("'")
                pic = requests.get(image_url, timeout=7)
                img_path = img_dir + image_name
                fp = open(img_path, 'wb')
                fp.write(pic.content)
                fp.close()
                image_count += 1
                name = threading.current_thread().name
                print(f'线程:{name} {image_name}保存成功 第{image_count}张')
            except:
                print(f'{image_name}出错啦')

    def run(self):
        global urls
        while True:
            lock.acquire()
            if len(urls) == 0:
                print('.............没有数据啦.............')
                lock.release()
                return
            """
            在Python中,列表是线程不安全的数据结构,因为它不是线程安全的。
            这意味着,在多线程环境中,如果多个线程同时对同一个列表进行操作,可能会导致竞态条件和数据不一致问题
            """
            url = urls[0]
            # 这里多一句代码就会演示出问题 出现线程安全问题
            # print('--------')
            del urls[0]
            name = threading.current_thread().name
            print(f'{name} 获取了数据{url}')
            # time.sleep(0.1)
            lock.release()
            self.down_load(url)


if __name__ == '__main__':
    """
        输入需要爬取的页数
        1. 输入需要爬取的页数
        2. 每页返回60个数据
        3. 请求数据
    """
    # page_num = int(input('请输入需要爬取的页数:'))
    # pn  页数
    # 0  第1页
    # 20 第2页
    # 40 第3页
    t1 = time.time()
    queue = []
    for index in range(3):
        spider = Spider(f'th-{index}')
        spider.start()
        queue.append(spider)
    for spider in queue:
        spider.join()
    t2 = time.time()
    # 4个线程
    # 结束用时:76.23213601112366
    # 1个线程
    print(f'结束用时:{t2-t1}')

2.处理爬取的图片

​ 2.1批量删除打不开的文件。由于爬取的图片有很多打开不了,自己写个了工具批量删除那些打开不了的图片

"""
由于爬取的图片中有很多无法打开,进行批量删除
"""
import glob
import cv2
import os

class RemoveBadImage:
    def Remvoe(self,path):
        img_paths = glob.glob(os.path.join(path,"*"))
        for img_path in img_paths:
            im = cv2.imread(img_path)
            try:
                #打不开的图片,opencv读取为NoneType,对其进行操作时会异常
                
                im.shape
            except Exception as e :
                捕获到异常时就删除该路径的图片
                os.remove(img_path)
                print(f"{img_path}已经删除")
        print("------")
path1 = "../images/minhang"
path2 = "../images/zhandou"
path3 = "../images/zhisheng"
r = RemoveBadImage()
r.Remvoe(path1)
r.Remvoe(path2)
r.Remvoe(path3)

​ 2.2手动删除无关图片

​ 由于下载的图片中还会有在这里插入图片描述
请添加图片描述

3.最终得到的数据集概览

在这里插入图片描述

在这里插入图片描述

2制作数据集

import glob
import os.path

import cv2
import torch
from torch.utils.data import Dataset

class LoadImageAndLabels(Dataset):
    def __init__(self,default_path = r"images",is_train = True):
        super().__init__()
        data = []
        self.data = data
        #获取图片的路径
        image_type = os.path.join(default_path,"train" if is_train else "test")
        image_path = os.path.join(image_type,"*","*")
        image_paths = glob.glob(image_path)
        #获取图片和分类标签
        #0 民航飞机
        #1 战斗机
        #2 直升机
        for img_path in image_paths:
            #取得图像的标签
            label = int(img_path.split("\\")[-2])

            #添加到data中
            data.append((img_path,label))
            # print(im_tensor_v,label)
    def __len__(self):
        return len(self.data)

    def __getitem__(self, index ):
        img_path ,label = self.data[index]
        im = cv2.imread(img_path)
        # 重置图片大小,宽高比 3:2
        im = cv2.resize(im, (100, 100))
        # 图片灰度化和归一化
        im_ = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY) / 255
        # 转为张量
        im_tensor = torch.tensor(im_, dtype=torch.float32)
        # 降维
        im_tensor_v = torch.flatten(im_tensor)
        #one—hot编码
        one_hot = torch.zeros(3)
        one_hot[label] = 1
        return im_tensor_v,one_hot

起初,我把图像的读取、灰度化等操作放在构造函数中,这导致我运行程序时启动的很慢,将这些操作的代码写在getitem函数中启动时间明显变快。经过思考后得出导致启动慢的原因。构造函数中若写了图像的读取、灰度化等操作,在创建LoadImageAndLabels时,会执行完这些操作后才执行之后的代码。然后,我们调用这些数据不需要全部调入处理,只需要分批次地处理,故造成了我运行项目时启动慢。

3.构造网络

"""
搭建网络

"""

"""
搭建网络

"""
import torch
from torch import nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(100*100,1024),
            nn.ReLU(),
            nn.Linear(1024,856),
            nn.ReLU(),
            nn.Linear(856,256),
            nn.ReLU(),
            nn.Linear(256,64),
            nn.ReLU(),
            nn.Linear(64,3),

        )
        self.softmax = nn.Softmax(dim= 1)

    def forward(self,x):
        return self.softmax(self.layers(x))

if __name__ == '__main__':
    net = Net()
    x = torch.rand(1,100*100)

    print(net(x).shape)



4.训练模型

import torch.optim
import tqdm

from mydataset import LoadImageAndLabels
from net import Net
from torch.utils.data import DataLoader
from torch import nn
from torch.utils.tensorboard import SummaryWriter
best_path = f"weights/best.pt"
test_accuracy_f = -1
device = "cpu"
class Trainer:
    def __init__(self):
        self.writer = SummaryWriter(f"logs")
        #准备训练集和测试集
        train_set = LoadImageAndLabels()
        test_set = LoadImageAndLabels(is_train= False)
        #加载数据
        self.train_loader = DataLoader(train_set,batch_size=30,shuffle=True)
        self.test_loader = DataLoader(test_set,batch_size=10,shuffle=True)
        #初始化网络
        net = Net()
        self.net = net
        #装载参数
        try:
            net.load_state_dict(torch.load(best_path))
        except Exception as e:
            print(e)
        net.to(device = device)
        #损失函数
        self.loss_fn = nn.MSELoss()

        #优化器
        self.opt = torch.optim.Adam(net.parameters(),lr=0.02)

    def train(self, epoch):
        sum_loss = 0
        sum_acc = 0

        for img_tensors ,targets in tqdm.tqdm(self.train_loader,desc="Training....",total=len(self.test_loader)):
            img_tensors = img_tensors.to(device)
            targets = targets.to(device)
            #向前传播
            outputs = self.net(img_tensors)
            #损失
            loss = self.loss_fn(outputs,targets)
            sum_loss += loss.item()

            #预测结果
            pre_cls = torch.argmax(outputs,dim=1)
            targets_cls = torch.argmax(targets,dim=1)

            #准确率
            accuracy = torch.mean(torch.eq(pre_cls,targets_cls).to(torch.float32))
            sum_acc += accuracy.item()

            #梯度清零
            self.opt.zero_grad()
            #反向传播
            loss.backward()
            #梯度更新
            self.opt.step()

        avg_loss = sum_loss/len(self.train_loader)
        avg_acc = sum_acc/len(self.train_loader)
        #添加记录
        self.writer.add_scalars("loss", {"train_avg_loss": avg_loss}, epoch)
        self.writer.add_scalars("acc", {"train_avg_acc": avg_acc}, epoch)
        print(f"train:loss:{round(avg_loss,3)} acc:{round(avg_acc,3)}")

    def test(self, epoch):
        sum_loss = 0
        sum_acc = 0

        for img_tensors,targets in tqdm.tqdm(self.test_loader,desc="Testing....",total=len(self.test_loader)):
            img_tensors = img_tensors.to(device)
            targets = targets.to(device)
            outputs = self.net(img_tensors)

            loss = self.loss_fn(outputs,targets)
            sum_loss += loss.item()

            pre_cls = torch.argmax(outputs,dim=1)
            targets_cls = torch.argmax(targets,dim=1)
            accuracy = torch.mean(torch.eq(pre_cls,targets_cls).to(torch.float32))
            sum_acc += accuracy.item()
"""
这里添加记录和计算评价损失和准确率代码缩进写错了,导致图像反馈的结果很奇怪
"""
            avg_loss = sum_loss/len(self.test_loader)
            avg_acc = sum_acc/len(self.test_loader)
        # 添加记录
            self.writer.add_scalars("loss", {"test_avg_loss": avg_loss}, epoch)
            self.writer.add_scalars("acc", {"test_avg_acc": avg_acc}, epoch)
            if avg_acc > test_accuracy_f:
                torch.save(self.net.state_dict(),f"weights/best.pt")
            print(f"test:loss:{round(avg_loss, 3)} acc:{round(avg_acc, 3)}")

    def run(self):
        for epoch in range(30):
            self.train(epoch)
            self.test(epoch)

if __name__ == '__main__':
    trainer = Trainer()
    trainer.run()



由于这里添加记录和计算评价损失和准确率代码位置写错,导致下图的奇怪结果
在这里插入图片描述

这是由于计算平均损失和计算平均准确率的代码写在了循环内部,在训练和测试时,每做一次训练就计算一次平均损失和平均准确度,而不是每一轮porch一次。

更改代码后

在这里插入图片描述

训练没有啥效果,我认为是数据集的问题,对爬取的图片进行了抠图,去除背景。我使用的是WPS的图片处理器,它能够批量化操作,每批最大数量为500。

在这里插入图片描述

抠完图后的训练集概览:

在这里插入图片描述

在这里插入图片描述

当然,其中还有一部分扣图的效果不理想,我就手动处理或删除。还有部分图片是重复的,由于数据集比较少(每个类别大概再三四百张)我将其旋转作为新的数据集。

抠图后,训练图结果


在这里插入图片描述

更换优化器,Adagrad 是一种自适应学习率算法,它会根据参数的历史梯度调整学习率,对于稀疏数据或者参数有不同尺度的情况下比较适用。

self.opt = torch.optim.Adagrad(net.parameters(),lr=0.005)

更换后的结果如下,从折线图看得出来与优化器关系不大。
在这里插入图片描述

从结果来看,我认为是自己的数据集不的问题,于是从开源数据库获取数据。这是在AI studio上找到的饿有关飞机的数据集,我只取了其中客机,战斗机和直升机的数据。

在这里插入图片描述

!

该数据集中,有很多类别,我只选用了客机、战斗机和直升机。训练集中的图片大小都是224*224三通道的图,图片无变形,每个类别的图片数量在3000左右。预期会有较好的效果。

在这里插入图片描述

测试集则是大小不一的彩图。
在这里插入图片描述

但是,在我换数据集后,报错如下图。
在这里插入图片描述

看到错误信息,opencv读取的图片是空的,首先想到在数据测试集中debug,看数据集中能否正确的读取到文件。

测试代码:

class LoadImageTest(Dataset):
    def __init__(self,defalt_path = "D:\\py\\datasets\\train\\keji"):
        super().__init__()
        data = []
        self.data = data
        img_paths = glob.glob(os.path.join(defalt_path,"*"))

        for img_path in img_paths:
            data.append(img_path)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, indx):
        img_path = self.data[indx]
        print(img_path)
        im = cv2.imread(img_path)
        # 重置图片大小
        im = cv2.resize(im, (100, 100))
        # 图片灰度化和归一化
        im_ = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY) / 255
        # 转为张量
        im_tensor = torch.tensor(im_, dtype=torch.float32)
        # 降维
        im_tensor_v = torch.flatten(im_tensor)
        return im_tensor_v

if __name__ == '__main__':
    l = LoadImageTest()
    #print(l[0])
    print(l[800])

测试的时候读取一张时正常

在这里插入图片描述

正常打印

在这里插入图片描述

当文件中有中文时,opencv就读取不到。并且报错信息和我换数据集后训练报的一样。

在这里插入图片描述

在这里插入图片描述

根据以上信息,我想删除图片名称中含有中文的图片。就进行了如下操作。

捕获异常的同时删除文件

    def __getitem__(self, index ):
        img_path ,label = self.data[index]
        try:
            im = cv2.imread(img_path)
            # 重置图片大小
            im = cv2.resize(im, (100, 100))
            # 图片灰度化和归一化
            im_ = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY) / 255
            # 转为张量
            im_tensor = torch.tensor(im_, dtype=torch.float32)
            # 降维
            im_tensor_v = torch.flatten(im_tensor)
            # one—hot编码
            one_hot = torch.zeros(3)
            one_hot[label] = 1
            return im_tensor_v, one_hot
        except Exception as e:
            print(e)
            os.remove(img_path)
            print(img_path+"已经删除")

删除结果

在这里插入图片描述

但是,因为客机(民航飞机)这个类别的数据集中很多图片名称都含有中文,都删除了会导致训练集数量锐减。我想到了批量将文件重命名的办法。

解决方法,将所有文件重新命名,在image-tool工具中写了一个重新命名的方法

def rename():
    i = 0
    file_path = []
    #由于只有这个文件夹中的文件含有中文名,故路径写死了
    path = "D:\\py\\plane\\images\\train\\0"

    for img_path in os.listdir(path):
        old_filename = os.path.join(path,img_path)
        new_filename = os.path.join(path, "m"+str(i) + ".jpg")
        os.rename(old_filename, new_filename)
        i = i + 1
        print("已经完成一次修改")

rename()

修改文件名后正常训练

在这里插入图片描述

loss acc 趋势 此时的学习率为0.01

在这里插入图片描述

调整学习率,降低为0.0001
在这里插入图片描述

发现模型过拟合,我认为数据集过少,又将数据集中的图片向水平方向翻转后添加到训练集中,以此来增加训练集。

写的工具:

def hcvtimg(path:str):
    for img_path in os.listdir(path):
        im = cv2.imread(os.path.join(path,img_path))
        #水平方向反转
        hor_img = cv2.flip(im,1)
        cv2.imwrite(os.path.join(path,"hor"+img_path),hor_img)
        print("成功保存一张水平翻转图")

增加训练集后。此时的学习率为0.0001,测试集的准确率稳定在80%.

在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值