The author trained ResNet50 on the Stanford Cars dataset. The dataset has 196 classes and 16,185 images: 8,144 in the training set and 8,041 in the test set.
The training images are laid out like this:
train/00001.jpg
train/00002.jpg
……
The corresponding labels are stored in a separate .mat file.
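The approach in this post pairs images and labels purely by position: the i-th label goes with the i-th image file in sorted order. That only works if sorting the file names reproduces their numeric order. Zero-padded names such as `00001.jpg` guarantee this, whereas unpadded names would not. A minimal sketch illustrating the point, using a hypothetical helper `pair_by_position` and made-up label values:

```python
def pair_by_position(filenames, labels):
    """Pair the i-th sorted filename with the i-th label (hypothetical helper)."""
    names = sorted(filenames)
    if len(names) != len(labels):
        raise ValueError(f"{len(names)} images but {len(labels)} labels")
    return list(zip(names, labels))


# Zero-padded names sort in numeric order
print(sorted(["00010.jpg", "00002.jpg", "00001.jpg"]))  # ['00001.jpg', '00002.jpg', '00010.jpg']
# Unpadded names do not: "10.jpg" sorts before "2.jpg"
print(sorted(["10.jpg", "2.jpg", "1.jpg"]))  # ['1.jpg', '10.jpg', '2.jpg']

# Labels 14 and 3 are made-up values for illustration only
print(pair_by_position(["00002.jpg", "00001.jpg"], [14, 3]))  # [('00001.jpg', 14), ('00002.jpg', 3)]
```

The length check matters because a single missing or extra file would silently shift every label after it.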
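The code below reads the .mat file and writes the labels to a txt file, one positive integer per line, each line being the label of the corresponding image: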
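import scipy.io

# Each entry of 'annotations' is a record with the fields
# bbox_x1, bbox_y1, bbox_x2, bbox_y2, class, fname; index 4 is the class (1..196)
data = scipy.io.loadmat('cars_train_annos.mat')
annotations = data['annotations']

with open('./train.txt', 'w') as f_train:
    for i in range(annotations.shape[1]):
        # each field is a small 2-D array, so .item() extracts the scalar value
        num = str(int(annotations[0, i][4].item()))
        print(i, num)
        f_train.write(num + '\n')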
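The class values written to `train.txt` run from 1 to 196, while `nn.CrossEntropyLoss` expects class indices in the range 0 to C-1. With a 196-way output layer, the labels therefore have to be shifted by one somewhere before training; the code in this section does not do that itself. A minimal sketch, assuming the one-label-per-line format above and a hypothetical helper name `labels_to_zero_based`, that validates the range and performs the shift:

```python
def labels_to_zero_based(lines, num_classes=196):
    """Convert 1-based label lines (as written to train.txt) to 0-based ints.

    Hypothetical helper: pass it an open file or any iterable of strings.
    """
    labels = []
    for line_no, line in enumerate(lines, start=1):
        line = line.strip()
        if not line:
            continue  # skip blank lines, e.g. a stray trailing newline
        label = int(line)
        if not 1 <= label <= num_classes:
            raise ValueError(f"line {line_no}: label {label} is outside 1..{num_classes}")
        labels.append(label - 1)  # 1..196 -> 0..195
    return labels


# Usage with the file written above:
# with open('./train.txt') as f:
#     train_labels = labels_to_zero_based(f)
print(labels_to_zero_based(["14\n", "3\n", "196\n"]))  # [13, 2, 195]
```

Doing the shift once, when the labels are loaded, keeps the txt file identical to the original annotations while the model sees 0-based targets.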
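With the images and their labels in place, we can write the custom Dataset class that the DataLoader will read from. The code is as follows: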
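import os
import time

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable  # deprecated since PyTorch 0.4; plain tensors work directly
from torch.utils.data import Dataset
from torchvision import models, transforms
from PIL import Image


# Use PIL to read an image and force 3-channel RGB
def default_loader(path):
    try:
        img = Image.open(path)
        return img.convert('RGB')
    except OSError:
        print("Cannot read image: {}".format(path))
        raise  # returning None would only fail later, inside the transforms


class customData(Dataset):
    def __init__(self, img_path, txt_path, dataset='', data_transforms=None, loader=default_loader):
        # one label per line, in the same order as the sorted image files
        with open(txt_path) as input_file:
            lines = input_file.readlines()
        self.img_label = [int(line.strip()) for line in lines if line.strip()]
        # list only the files directly under img_path, sorted, so the i-th path
        # matches the i-th label (os.walk would also descend into subdirectories)
        self.img_name = [os.path.join(img_path, name)
                         for name in sorted(os.listdir(img_path))
                         if name.lower().endswith('.jpg')]
        if len(self.img_name) != len(self.img_label):
            raise ValueError('{} images but {} labels'.format(
                len(self.img_name), len(self.img_label)))
        self.data_transforms = data_transforms
        self.dataset = dataset
        self.loader = loader
        # The key step is storing the image paths in self.img_name and the corresponding labels in sel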