Pytorch入门进行迁移学习实现自行车分类识别:获取数据集与准备数据

前言

迁移学习是一种机器学习方法,它利用已经训练好的模型在新任务上进行训练,从而提高模型的性能和泛化能力。在本文中,我们将使用PyTorch实现一个基于预训练模型的迁移学习模型,用于单车分类识别。

项目概述

我们的目标是创建一个能够识别不同类型自行车的图像分类模型。为实现这一目标,我们首先需要获取一个包含大量自行车图片的数据集。由于公开可用的数据集可能不完全满足特定需求,我们决定使用爬虫技术从互联网上抓取自行车图片。

数据集爬取

使用Python的requests库构建网络爬虫,从网页中提取图片。

def get_images_from_baidu(keyword, page_num, save_dir):
    header = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/78.0.3904.108 Safari/537.36'}
    # 请求的 url
    url = 'https://image.baidu.com/search/acjson?'
    n = 0
    for pn in range(0, 30 * page_num, 30):
        # 请求参数
        param = {'tn': 'resultjson_com',
                 'logid': '7603311155072595725',
                 'ipn': 'rj',
                 'ct': 201326592,
                 'is': '',
                 'fp': 'result',
                 'queryWord': keyword,
                 'cl': 2,
                 'lm': -1,
                 'ie': 'utf-8',
                 'oe': 'utf-8',
                 'adpicid': '',
                 'st': -1,
                 'z': '',
                 'ic': '',
                 'hd': '',
                 'latest': '',
                 'copyright': '',
                 'word': keyword,
                 's': '',
                 'se': '',
                 'tab': '',
                 'width': '',
                 'height': '',
                 'face': 0,
                 'istype': 2,
                 'qc': '',
                 'nc': '1',
                 'fr': '',
                 'expermode': '',
                 'force': '',
                 'cg': '',    # 这个参数没公开,但是不可少
                 'pn': pn,    # 显示:30-60-90
                 'rn': '30',  # 每页显示 30 条
                 'gsm': '1e',
                 '1618827096642': ''
                 }
        request = requests.get(url=url, headers=header, params=param)
        if request.status_code == 200:
            print('Request success.')
        request.encoding = 'utf-8'
        # 正则方式提取图片链接
        html = request.text
        image_url_list = re.findall('"thumbURL":"(.*?)",', html, re.S)

        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        for image_url in image_url_list:
            image_data = requests.get(url=image_url, headers=header).content
            with open(os.path.join(save_dir, f'{n:06}.jpg'), 'wb') as fp:
                fp.write(image_data)
            n = n + 1  

我们将单车分成5类:hello(哈罗单车), meituan(美团单车), qingju(青桔单车),zijiadanche(自用单车),qita(其它单车)

数据清理

获取到图片后,需要对数据进行清洗,去除不相关或质量不高的图片进行过滤。

对数据进行标注

针对于不同的类别,标注数据集。

train_name_file = open("data/biycle/train.txt", "w")
test_name_file = open("data/biycle/test.txt", "w")

train_label_file = open("data/biycle/train_label.txt", "w")
test_label_file = open("data/biycle/test_label.txt", "w")

i =0
for name in class_names:
    path = root_path + name 
    file_names = get_filenames_in_folder(path)
    print(len(file_names))
    j =0
    for path in file_names:
        if (j%8 ==0):
            test_name_file.write(path + '\n')    
            test_label_file.write(str(i) + '\n')
        else:
            train_name_file.write(path + '\n')
            train_label_file.write(str(i) + '\n')
        j+=1
    i += 1

编写数据加载模块

class CustomImageDataset(Dataset):
    def __init__(self, data_path, model, transform=None, target_transform=None):
        self.data_path = data_path
        self.model = model
        self.img_labels = []
        self.image_lists =[]
        self.transform = transform
        self.target_transform = target_transform
        self.obtain_label_image()

    def __len__(self):
        return len(self.img_labels)


    def __getitem__(self, idx):
        #print(self.image_lists[idx])
        image = cv2.imread(self.image_lists[idx])
        image =cv2.resize(image, (32,32))
        label = self.img_labels[idx]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

    def obtain_label_image(self):
        if(self.model == "train"):
            # 指定文件夹路径
            folder_path = self.data_path + 'train.txt'
            with open(folder_path, 'r') as file:
                # 逐行读取文件内容
                for line in file:
                    self.image_lists.append(line.strip())

            file_path = self.data_path + 'train_label.txt'  # 替换为实际文件路径
            with open(file_path, 'r') as file:
                # 逐行读取文件内容
                for line in file:
                    # 处理每一行的数据,例如打印或存储
                    self.img_labels.append(int(line.strip()))  # 使用strip()方法去除行末的换行符
        if (self.model == "test"):
            folder_path = self.data_path + 'test.txt'
            with open(folder_path, 'r') as file:
                # 逐行读取文件内容
                for line in file:
                    self.image_lists.append(line.strip())

            file_path = self.data_path + 'test_label.txt'  # 替换为实际文件路径
            with open(file_path, 'r') as file:
                # 逐行读取文件内容
                for line in file:
                    # 处理每一行的数据,例如打印或存储
                    self.img_labels.append(int(line.strip()))  # 使用strip()方法去除行末的换行符

总结

当我们没有数据集的时候,使用爬虫技术获取数据集,我们能快速获取自己的数据集,构建自己的数据加载模块,使得我们能够胜任不同类型的数据加载。

关注我的公众号auto_driver_ai(Ai fighting), 第一时间获取更新内容。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值