第一种:通过 torchvision(PyTorch 生态)内置的 ImageFolder() 方法
以下面数据集为例:
# 1. Build the data source with torchvision's built-in ImageFolder() helper.
data_dir = r'D:/Projects/Datasets/flower_data/'
# Use os.path.join instead of string concatenation: data_dir already ends
# with '/', so the original '+'-concatenation produced a double slash
# ('.../flower_data//train').
train_dir = os.path.join(data_dir, 'train')
valid_dir = os.path.join(data_dir, 'valid')
cat_to_name_file = r'D:/Projects/Datasets/flower_data/cat_to_name.json'
# Transform pipelines for ResNet (224x224 input).
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomRotation(45),           # random rotation, angle drawn from [-45, 45] degrees
        transforms.CenterCrop(224),              # crop a 224x224 patch from the image center
        transforms.RandomHorizontalFlip(p=0.5),  # horizontal flip with probability 0.5
        transforms.RandomVerticalFlip(p=0.5),    # vertical flip with probability 0.5
        # arg 1: brightness, arg 2: contrast, arg 3: saturation, arg 4: hue
        transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),
        transforms.RandomGrayscale(p=0.025),     # convert to grayscale (3 channels, R=G=B) with probability 0.025
        transforms.ToTensor(),
        # per-channel mean, std (ImageNet statistics)
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'valid': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}
# Transform pipelines for Inception-v3 (299x299 input). Train and valid use
# the same deterministic resize/crop/normalize sequence, so build each from
# one factory to avoid repeating the pipeline definition.
def _inception_pipeline():
    """Return a fresh Compose pipeline sized for Inception-v3 (299x299)."""
    return transforms.Compose([
        transforms.Resize(299),
        transforms.CenterCrop(299),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
data_transforms1 = {
    'train': _inception_pipeline(),
    'valid': _inception_pipeline(),
}
batch_size = 8
# Build the dataset, loader, and size bookkeeping for each split with one
# explicit loop instead of three parallel dict comprehensions.
image_datasets = {}
dataloaders = {}
dataset_sizes = {}
for split in ('train', 'valid'):
    image_datasets[split] = datasets.ImageFolder(os.path.join(data_dir, split), data_transforms[split])
    dataloaders[split] = DataLoader(image_datasets[split], batch_size=batch_size, shuffle=True)
    dataset_sizes[split] = len(image_datasets[split])
class_names = image_datasets['train'].classes
#print(image_datasets)
print(dataloaders)
print(class_names)
print(dataset_sizes)
def showImage(tensor):
    """Undo ImageNet normalization and convert a CHW tensor to an HWC numpy image.

    tensor: normalized image tensor of shape (C, H, W), possibly on GPU.
    return: numpy array of shape (H, W, C) with values clipped to [0, 1].
    """
    img = tensor.to("cpu").clone().detach().numpy().squeeze()
    # CHW -> HWC for matplotlib display.
    img = np.transpose(img, (1, 2, 0))
    # Invert Normalize: x * std + mean (ImageNet statistics).
    std = np.array((0.229, 0.224, 0.225))
    mean = np.array((0.485, 0.456, 0.406))
    return np.clip(img * std + mean, 0, 1)
# Preview one validation batch with human-readable class names.
fig = plt.figure(figsize=(20, 12))
columns = 4
rows = 2
with open(cat_to_name_file, 'r') as f:
    cat_to_name = json.load(f)  # maps class id (string) -> flower name
print(cat_to_name)
dataiter = iter(dataloaders['valid'])
# dataiter is already an iterator: call next() on it directly
# (the original wrapped it in a redundant second iter()).
inputs, classes = next(dataiter)
#print('inputs:', inputs)
print('classes:', classes)
for idx in range(columns * rows):
    ax = fig.add_subplot(rows, columns, idx + 1, xticks=[], yticks=[])
    # class_names[classes[idx]] is the folder name (a numeric string);
    # map it through cat_to_name to get the flower's display name.
    ax.set_title(cat_to_name[str(int(class_names[classes[idx]]))])
    plt.imshow(showImage(inputs[idx]))
plt.show()
第二种:通过自定义 Dataset(配合 DataLoader)来处理数据
以下面数据集为例:
# 2. Data loading via a custom Dataset subclass used with DataLoader.
data_dir = r'D:/Projects/Datasets/flower_photos/'
# Per-split preprocessing: random crop + flip augmentation for training,
# deterministic resize/center-crop for validation.
data_transform = {
    "train": transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]),
    "val": transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]),
}
class FlowerDataset(Dataset):
    """Flower dataset driven by a plain-text annotation file.

    Each line of ann_file looks like 'daisy/xxx.jpg daisy': an image path
    relative to root_dir whose leading directory is the class name.
    """

    def __init__(self, root_dir, ann_file, transform):
        # Dataset root directory and annotation-file path.
        self.root_dir = root_dir
        self.ann_file = ann_file
        # {relative image path: integer label}, parsed from ann_file.
        self.img_label = self.load_annotations()
        # Absolute image paths, in annotation-file order.
        self.img = [os.path.join(self.root_dir, p) for p in self.img_label.keys()]
        # Integer labels aligned with self.img.
        # NOTE: renamed from the original misspelling 'lable'; nothing else
        # in this file referenced the old attribute name.
        self.label = list(self.img_label.values())
        self.transform = transform

    def __getitem__(self, idx):
        """Return one (image, label) pair; the image is transformed if a transform was given."""
        # convert('RGB') guards against grayscale/RGBA files, which would
        # otherwise break a 3-channel Normalize later in the pipeline.
        image = Image.open(self.img[idx]).convert('RGB')
        label = self.label[idx]
        if self.transform:
            image = self.transform(image)
        # Labels are plain ints; go through numpy to get a torch tensor.
        label = torch.from_numpy(np.array(label))
        return image, label

    def __len__(self):
        # Number of images = number of annotation lines.
        return len(self.img)

    def load_annotations(self):
        """Parse ann_file into {relative image path: integer label}.

        Labels are assigned from the *sorted* set of class names so the
        numbering is deterministic across runs; the original used a bare
        list(set(...)), whose order depends on string-hash randomization.
        Reads the file once instead of readlines()/seek(0)/readlines().
        """
        with open(self.ann_file, encoding='utf-8') as f:
            lines = [ln.strip() for ln in f if ln.strip()]
        names_list = [ln.split(' ')[0] for ln in lines]   # relative image path
        labels_list = [ln.split('/')[0] for ln in lines]  # class name = leading directory
        # Deterministic class-name -> index mapping, e.g. {'daisy': 0, ...}.
        labels_type = {name: i for i, name in enumerate(sorted(set(labels_list)))}
        # {relative path: integer label} in file order.
        return {name: labels_type[cls] for name, cls in zip(names_list, labels_list)}
# Build the custom dataset from the annotation file and wrap it in a loader.
ann_file = 'D:/Projects/Datasets/flower_photos/annotations.txt'
train_dataset = FlowerDataset(root_dir='D:/Projects/Datasets/flower_photos',
                              ann_file=ann_file,
                              transform=data_transform["train"])
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
# Sanity-check the pipeline: fetch one batch and display a sample image.
image, label = next(iter(train_loader))
sample = image[1].squeeze()                 # second image of the batch: (3, 224, 224)
sample = sample.permute((1, 2, 0)).numpy()  # CHW -> HWC for matplotlib display
plt.imshow(sample)
plt.show()
print('label is {}'.format(label[0].numpy()))  # fixed typo in output: 'lable' -> 'label'
print(image.shape)