前言
迁移学习是一种机器学习方法,它利用已经训练好的模型在新任务上进行训练,从而提高模型的性能和泛化能力。在本文中,我们将使用PyTorch实现一个基于预训练模型的迁移学习模型,用于单车分类识别。
项目概述
我们的目标是创建一个能够识别不同类型自行车的图像分类模型。为实现这一目标,我们首先需要获取一个包含大量自行车图片的数据集。由于公开可用的数据集可能不完全满足特定需求,我们决定使用爬虫技术从互联网上抓取自行车图片。
数据集爬取
使用Python的requests库构建网络爬虫,从网页中提取图片。
def get_images_from_baidu(keyword, page_num, save_dir):
header = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/78.0.3904.108 Safari/537.36'}
# 请求的 url
url = 'https://image.baidu.com/search/acjson?'
n = 0
for pn in range(0, 30 * page_num, 30):
# 请求参数
param = {'tn': 'resultjson_com',
'logid': '7603311155072595725',
'ipn': 'rj',
'ct': 201326592,
'is': '',
'fp': 'result',
'queryWord': keyword,
'cl': 2,
'lm': -1,
'ie': 'utf-8',
'oe': 'utf-8',
'adpicid': '',
'st': -1,
'z': '',
'ic': '',
'hd': '',
'latest': '',
'copyright': '',
'word': keyword,
's': '',
'se': '',
'tab': '',
'width': '',
'height': '',
'face': 0,
'istype': 2,
'qc': '',
'nc': '1',
'fr': '',
'expermode': '',
'force': '',
'cg': '', # 这个参数没公开,但是不可少
'pn': pn, # 显示:30-60-90
'rn': '30', # 每页显示 30 条
'gsm': '1e',
'1618827096642': ''
}
request = requests.get(url=url, headers=header, params=param)
if request.status_code == 200:
print('Request success.')
request.encoding = 'utf-8'
# 正则方式提取图片链接
html = request.text
image_url_list = re.findall('"thumbURL":"(.*?)",', html, re.S)
if not os.path.exists(save_dir):
os.makedirs(save_dir)
for image_url in image_url_list:
image_data = requests.get(url=image_url, headers=header).content
with open(os.path.join(save_dir, f'{n:06}.jpg'), 'wb') as fp:
fp.write(image_data)
n = n + 1
我们将单车分成5类:hello(哈罗单车), meituan(美团单车), qingju(青桔单车),zijiadanche(自用单车),qita(其它单车)
数据清理
获取到图片后,需要对数据进行清洗,去除不相关或质量不高的图片进行过滤。
对数据进行标注
针对于不同的类别,标注数据集。
train_name_file = open("data/biycle/train.txt", "w")
test_name_file = open("data/biycle/test.txt", "w")
train_label_file = open("data/biycle/train_label.txt", "w")
test_label_file = open("data/biycle/test_label.txt", "w")
i =0
for name in class_names:
path = root_path + name
file_names = get_filenames_in_folder(path)
print(len(file_names))
j =0
for path in file_names:
if (j%8 ==0):
test_name_file.write(path + '\n')
test_label_file.write(str(i) + '\n')
else:
train_name_file.write(path + '\n')
train_label_file.write(str(i) + '\n')
j+=1
i += 1
编写数据加载模块
class CustomImageDataset(Dataset):
def __init__(self, data_path, model, transform=None, target_transform=None):
self.data_path = data_path
self.model = model
self.img_labels = []
self.image_lists =[]
self.transform = transform
self.target_transform = target_transform
self.obtain_label_image()
def __len__(self):
return len(self.img_labels)
def __getitem__(self, idx):
#print(self.image_lists[idx])
image = cv2.imread(self.image_lists[idx])
image =cv2.resize(image, (32,32))
label = self.img_labels[idx]
if self.transform:
image = self.transform(image)
if self.target_transform:
label = self.target_transform(label)
return image, label
def obtain_label_image(self):
if(self.model == "train"):
# 指定文件夹路径
folder_path = self.data_path + 'train.txt'
with open(folder_path, 'r') as file:
# 逐行读取文件内容
for line in file:
self.image_lists.append(line.strip())
file_path = self.data_path + 'train_label.txt' # 替换为实际文件路径
with open(file_path, 'r') as file:
# 逐行读取文件内容
for line in file:
# 处理每一行的数据,例如打印或存储
self.img_labels.append(int(line.strip())) # 使用strip()方法去除行末的换行符
if (self.model == "test"):
folder_path = self.data_path + 'test.txt'
with open(folder_path, 'r') as file:
# 逐行读取文件内容
for line in file:
self.image_lists.append(line.strip())
file_path = self.data_path + 'test_label.txt' # 替换为实际文件路径
with open(file_path, 'r') as file:
# 逐行读取文件内容
for line in file:
# 处理每一行的数据,例如打印或存储
self.img_labels.append(int(line.strip())) # 使用strip()方法去除行末的换行符
总结
当我们没有数据集的时候,使用爬虫技术获取数据集,我们能快速获取自己的数据集,构建自己的数据加载模块,使得我们能够胜任不同类型的数据加载。
关注我的公众号auto_driver_ai(Ai fighting), 第一时间获取更新内容。