1.作业简介:猫狗大战
1.1问题描述:
在这个问题中,你将面临一个经典的机器学习分类挑战——猫狗大战。你的任务是建立一个分类模型,能够准确地区分图像中是猫还是狗。
1.2预期解决方案:
你的目标是通过训练一个机器学习模型,使其在给定一张图像时能够准确地预测图像中是猫还是狗。模型应该能够推广到未见过的图像,并在测试数据上表现良好。我们期待您将其部署到模拟的生产环境中——这里推理时间和二分类准确度(F1分数)将作为评分的主要依据。
1.3数据集
数据集:
链接:百度网盘 请输入提取码
提取码:jc34
1.4图像展示
2.数据预处理
2.1数据集结构
本项目数据集只有一个train数据集,我将其切分三部分,test,train,val文件夹。
其中cat文件夹下包含了猫的图片,dog文件夹下包含了狗的图片。
test集中包含的数据:
train和val集中包含的数据量:
2.2探索性数据分析
在这里,我分别取了train数据集下的随机不重样10只狗,10只猫的图像进行展示:
import matplotlib.pyplot as plt
import os
from PIL import Image
cat_dir = './train/cat'
dog_dir = './train/dog'
cat_image_list = os.listdir(cat_dir)
dog_image_list = os.listdir(dog_dir)
show_image = [os.path.join(cat_dir, cat_image_list[i]) for i in range(10)]
show_image.extend([os.path.join(dog_dir, dog_image_list[i]) for i in range(10)])
plt.figure()
for i in range(1, 21):
plt.subplot(4, 5, i)
img = Image.open(show_image[i-1])
plt.imshow(img)
plt.show()
2.3提取数据集
在本项目中,为了更好地提取出图像,我构建了一个函数,能够将每个主文件夹下的图片提取出来,并且打好了标签。
# 创建自定义数据集
class SelfDataset(Dataset):
def __init__(self, root_dir, transform=None):
self.root_dir = root_dir
self.transform = transform
self.classes = ['cat', 'dog']
self.data = self.load_data()
def load_data(self):
data = []
for class_idx, class_name in enumerate(self.classes):
class_path = os.path.j