There are plenty of tutorials online for packaging your own data into CIFAR-10-style batch files, but once you have produced something like data_batch_1 there is still a missing step: turning that file into the dataset object the network actually consumes. I could not find this step written up anywhere; if you know of a reference, please let me know. For packaging your own data into a data_batch_1-style file, see https://blog.csdn.net/weixin_40437821/article/details/103603700
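Before going further, it helps to see what such a batch file is expected to contain. Here is a minimal sketch (my own illustration, not taken from the post linked above) of packing images into a data_batch_train-style pickle; the 'data'/'labels' keys, the uint8 dtype, and the flattened CHW layout are assumptions chosen only so that they match what the SOLAR5 loader below expects for 120x120 RGB images:

import pickle
import numpy as np
from PIL import Image

def make_batch(image_paths, labels, out_path):
    # pack each image as one flattened row: R plane, then G, then B (CHW order),
    # so that reshape(-1, 3, 120, 120) in the loader recovers it
    rows = []
    for p in image_paths:
        img = Image.open(p).convert('RGB').resize((120, 120))
        arr = np.asarray(img, dtype=np.uint8)              # HWC, shape (120, 120, 3)
        rows.append(arr.transpose(2, 0, 1).reshape(-1))    # CHW, flattened to length 43200
    batch = {'data': np.stack(rows), 'labels': list(labels)}
    with open(out_path, 'wb') as f:
        pickle.dump(batch, f)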
From data_batch_1 to a dataset:
Modeled on cifar.py in torchvision.datasets, I wrote a class that turns my own data_batch_train (the counterpart of data_batch_1 in the CIFAR-10 dataset) into readable data:
For CIFAR-10's data_batch_1, the original program first instantiates the CIFAR10 class from cifar.py; the code is:
module = getattr(torchvision.datasets, 'CIFAR10')
train_dataset = module(config.dataset.dataset_dir,  # per the docs: location of the cifar-10-batches-py folder; if it is missing, the data is downloaded there
                       train=is_train,              # bool; True loads all of the training batches
                       transform=train_transform,   # a chain of data augmentations
                       download=True)               # whether to download the dataset from the official site
Here torchvision.datasets simply re-exports torchvision.datasets.cifar, so module is bound to the CIFAR10 class defined in torchvision.datasets.cifar (getattr just looks the class up by name; it does not create a subclass).
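A quick check of that point (my own illustration, not part of the original code): getattr returns the very same class object, so instantiating module is identical to instantiating torchvision.datasets.CIFAR10 directly:

import torchvision.datasets

module = getattr(torchvision.datasets, 'CIFAR10')
print(module is torchvision.datasets.CIFAR10)  # True: module IS the class, not a subclass of it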
Change it to:
module2 = getattr(solar, 'SOLAR5')
train_dataset = module2(config.dataset.dataset_dir,
                        train=True,
                        transform=train_transform,
                        target_transform=None)
Because I am working with a solar panel dataset, I replaced cifar.py with my own solar.py; correspondingly, module2 is bound to the SOLAR5 class in the solar module. Below is the concrete implementation of the SOLAR5 class:
File name: solar.py
from PIL import Image
import os
import os.path
import numpy as np
# pickle is a binary serialization format: Python-specific and not human-readable
import pickle
import hashlib
from torchvision.datasets.vision import VisionDataset


class SOLAR5(VisionDataset):
    base_folder = 'solar-5-batches-py'
    train_list = [['data_batch_train', '0715171a5755bc64941b47a0ab3b8f72'], ]
    test_list = [['data_batch_val', '59d0df1891e62397d65f6fcf903ba2d7']]

    def __init__(self, root, train=True, transform=None, target_transform=None):
        super(SOLAR5, self).__init__(root, transform=transform,
                                     target_transform=target_transform)
        self.train = train

        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted. '
                               'You can use download=True to download it')

        if self.train:
            downloaded_list = self.train_list
        else:
            downloaded_list = self.test_list

        self.data = []
        self.targets = []

        # now load the pickled numpy arrays: iterate over every batch in downloaded_list
        for file_name, checknum in downloaded_list:
            file_path = os.path.join(self.root, self.base_folder, file_name)
            with open(file_path, 'rb') as f:
                # deserialize the pickle byte stream back into Python objects
                entry = pickle.load(f, encoding='latin1')
                # collect each batch's 'data' array; after the loop self.data is a list with one array per batch
                self.data.append(entry['data'])
                if 'labels' in entry:
                    # collect each batch's labels into one flat list of targets
                    self.targets.extend(entry['labels'])
                else:
                    self.targets.extend(entry['fine_labels'])

        self.data = np.vstack(self.data).reshape(-1, 3, 120, 120)
        self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC for PIL
        # self._load_meta()

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], self.targets[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        return len(self.data)

    # def _load_meta(self):
    #     path = os.path.join(self.root, self.base_folder, self.meta['filename'])
    #     if not check_integrity(path, self.meta['md5']):
    #         raise RuntimeError('Dataset metadata file not found or corrupted. '
    #                            'You can use download=True to download it')
    #     with open(path, 'rb') as infile:
    #         data = pickle.load(infile, encoding='latin1')  # decode batches.meta
    #         self.classes = data[self.meta['key']]
    #         self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)}

    def _check_integrity(self):
        root = self.root
        for fentry in (self.train_list + self.test_list):
            filename, md5 = fentry[0], fentry[1]
            fpath = os.path.join(root, self.base_folder, filename)
            if not check_integrity(fpath, md5):
                return False
        return True


# the helpers below come from torchvision.datasets.utils
def check_integrity(fpath, md5=None):
    if not os.path.isfile(fpath):
        return False
    if md5 is None:
        return True
    return check_md5(fpath, md5)


def check_md5(fpath, md5, **kwargs):
    return md5 == calculate_md5(fpath, **kwargs)


# a hash (digest) algorithm turns the file into a fixed-length fingerprint for comparison
def calculate_md5(fpath, chunk_size=1024 * 1024):
    md5 = hashlib.md5()  # get an md5 hash object
    with open(fpath, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            md5.update(chunk)  # feed the file contents chunk by chunk
    return md5.hexdigest()  # return the digest as a hex string
The md5 checksums in train_list = [['data_batch_train', '0715171a5755bc64941b47a0ab3b8f72'],] have to be computed for your own batch files; the following script computes the md5 of a pickle file:
# reuse calculate_md5 from torchvision.datasets.utils to compute the md5 of a pickle file
import hashlib


def calculate_md5(fpath, chunk_size=1024 * 1024):
    md5 = hashlib.md5()
    with open(fpath, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            md5.update(chunk)
    return md5.hexdigest()


path = 'Dataset/SOLAR/solar-5-batches-py/data_batch_val'

if __name__ == '__main__':
    aaa = calculate_md5(path, chunk_size=1024 * 1024)
    print(aaa)
Here path is the location of the batch file; each batch file has its own md5 checksum, and the printed value is what goes into train_list / test_list.
With that, the pickled dataset loads successfully!
A successful read prints:
Dataset SOLAR5
    Number of datapoints: 529
    Root location: Dataset/SOLAR
StandardTransform
Transform: Compose(
    <pytorch_image_classification.transforms.transforms.RandomCrop object at 0x0000022167D48A58>
    <pytorch_image_classification.transforms.transforms.RandomHorizontalFlip object at 0x0000022162D18358>
    <pytorch_image_classification.transforms.transforms.Normalize object at 0x0000022162D18438>
    <pytorch_image_classification.transforms.transforms.ToTensor object at 0x0000022162D184A8>
)

Process finished with exit code 0
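As a final sanity check (my own sketch, not part of the original write-up), you can wrap the resulting train_dataset in a DataLoader and inspect one batch; the exact spatial size depends on the RandomCrop settings in train_transform:

from torch.utils.data import DataLoader

loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
images, targets = next(iter(loader))
print(images.shape)   # e.g. torch.Size([16, 3, H, W]) for whatever H, W the crop produces
print(targets.shape)  # torch.Size([16]) of class indices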