代码以及数据集下载:https://github.com/duchp/python-all/tree/master/CV%20code/动物多分类项目
一、任务介绍
- 纲分类问题,预测该动物是属于哺乳纲(Mammals)还是鸟纲(Birds)
- 种分类问题,预测该动物是兔子、老鼠还是鸡
- 多任务分类问题,同时预测动物的“纲”和“种”
二、数据预处理
使用的数据集结构如下:
上图为文件夹数据结构,同时,我们对于每一个任务,都有两个.csv文件,用于存储对应的训练数据与测试数据的标签。
训练的第一步是将数据存为pytorch里面的Dataset类,便于后续处理和读取数据。
因为三个任务的数据一致,不同的是标签情况,在代码上差异不大,所以以第一个任务——纲分类为例,数据存为Dataset类的代码及注释如下:
class MyDataset(torch.utils.data.Dataset):
def __init__(self,root,transform=None):
super(MyDataset,self).__init__()
#读取csv文件
file_info = pd.read_csv(root, index_col=0)
#读取图像路径
file_path = file_info['path']
#读取图像纲分类标签
file_class = file_info['classes']
imgs = []
imglb = []
#依次处理数据
for i in range(len(file_path)):
path = file_path[i]
path = path.replace('\\','/')
if not os.path.isfile(path):
print(path + ' does not exist!')
return None
#读对应路径的图像,存为img,依次存入imgs
img = Image.open(path).convert('RGB')
imgs.append(img)
#读对应图像的标签,依次存入imglb
imglb.append(int(file_class[i]))
self.image = imgs
self.imglb = imglb
self.root = root
self.size = len(file_info)
self.transform = transform
def __getitem__(self,index):
img = self.image[index]
label = self.imglb[index]
sample = {'image': img,'classes':label}
if self.transform:
sample['image'] = self.transform(img)
return sample
def __len__(self):
return self.size
存好数据后,开始为训练网络做准备,在深度学习训练中,一般会对数据进行一些变换,比如:翻转,旋转等等,用来增强鲁棒性,因此,这里需要定义一个数据变换函数:
train_transforms = transforms.Compose([transforms.Resize((500, 500)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
])
val_transforms = transforms.Compose([transforms.Resize((500, 500)),
transforms.ToTensor()
])
这里我们只做了训练图像的resize和随机水平翻转变换,实际上图像变换还有很多选择,具体可查询torchvision.transforms.Compose这个函数的参数设置。这个设定是要赋值到上面的MyDataset里面的,可以看到上面的transform默认为None,如果不赋值,则不做变换。
到了这一步,我们的数据还需要一个非常重要的处理,那就是分batch,因为如果一次性把数据放入训练中,往往内存是不够的,如果一张一张的训练,每次修正方向以