Reposted from: CSDN
Original: https://blog.csdn.net/Hungryof/article/details/76649006
What torchvision is mainly used for
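There are two kinds of dataset layouts:
- All images are in the same folder. (A torch.utils.data.Dataset subclass is enough for this!)
- Images of different classes are placed in different folders. (Use torchvision.datasets.ImageFolder('image_dir_root').)
Data for most tasks is of the first kind. The second kind usually appears in classification tasks: for example, ImageNet has 1000 classes, stored in 1000 folders.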
For the second kind, the directory structure looks like this:
root/ants/xxx.png
root/ants/xxy.jpeg
root/ants/xxz.png
...
root/bees/123.jpg
root/bees/nsdf3.png
root/bees/asd932_.png
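For this one-folder-per-class layout, ImageFolder treats each subfolder name as a class, sorts the class names, and maps them to consecutive integer labels (exposed as class_to_idx). The sketch below reproduces only that indexing step with the standard library; scan_image_folder and the short extension list are hypothetical helpers for illustration, not torchvision's actual code.

```python
import os
import tempfile

# Hypothetical extension whitelist; torchvision accepts more formats.
IMG_EXTENSIONS = (".png", ".jpg", ".jpeg")


def scan_image_folder(root):
    """Return (samples, class_to_idx) for a root/<class>/<file> layout."""
    # Every subdirectory of root is one class; sorting makes labels deterministic.
    classes = sorted(d for d in os.listdir(root)
                     if os.path.isdir(os.path.join(root, d)))
    class_to_idx = {name: i for i, name in enumerate(classes)}
    samples = []
    for name in classes:
        class_dir = os.path.join(root, name)
        for fname in sorted(os.listdir(class_dir)):
            # Non-image files (e.g. notes.txt) are skipped.
            if fname.lower().endswith(IMG_EXTENSIONS):
                samples.append((os.path.join(class_dir, fname), class_to_idx[name]))
    return samples, class_to_idx


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as root:
        # Recreate a tiny ants/bees layout with empty placeholder files.
        for rel in ["ants/xxx.png", "ants/xxy.jpeg", "bees/123.jpg", "bees/notes.txt"]:
            path = os.path.join(root, rel)
            os.makedirs(os.path.dirname(path), exist_ok=True)
            open(path, "w").close()
        samples, class_to_idx = scan_image_folder(root)
        print(class_to_idx)  # {'ants': 0, 'bees': 1}
        print([(os.path.basename(p), y) for p, y in samples])
```

Each sample is a (path, label) pair; the real ImageFolder then opens the file and applies its transform when the sample is indexed.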
Note: the torchvision package serves three purposes:
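- It provides popular models (torchvision.models) and can load common datasets directly (torchvision.datasets).
- It extends torch.utils.data.Dataset, chiefly with support for reading data stored as one folder per class: torchvision.datasets.ImageFolder is a subclass of torch.utils.data.Dataset! Both give you an indexable dataset object (dataset[i] returns one sample), which a DataLoader then iterates over.
- It provides ready-made torchvision.transforms, so you don't have to write them yourself.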
Two ways of reading data
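The usual building blocks are:
1. torch.utils.data.Dataset (the low-level base class), a custom class derived from it, or torchvision.datasets.ImageFolder, which also derives from it.
2. torchvision.transforms, applied to the images read in step 1.
3. torch.utils.data.DataLoader, which wraps the dataset from step 2 and loads it in parallel using multiple worker processes.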
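What a DataLoader does with a dataset boils down to: produce indices (shuffled or not), fetch dataset[i] for each one, and group the samples into batches. The pure-Python sketch below shows only that control flow; simple_loader is a hypothetical stand-in, and the real DataLoader additionally collates samples into tensors and can fetch them in num_workers worker processes.

```python
import random


def simple_loader(dataset, batch_size, shuffle=False, seed=None):
    """Yield lists of samples, mimicking DataLoader's indexing and batching."""
    indices = list(range(len(dataset)))  # relies on the dataset's __len__
    if shuffle:
        random.Random(seed).shuffle(indices)
    for start in range(0, len(indices), batch_size):
        # Each sample is fetched through __getitem__, as DataLoader does.
        yield [dataset[i] for i in indices[start:start + batch_size]]


# Any object with __len__ and __getitem__ works; a list is the simplest example.
data = ["a", "b", "c", "d", "e"]
print(list(simple_loader(data, batch_size=2)))  # [['a', 'b'], ['c', 'd'], ['e']]
```

The last, smaller batch is kept here, which matches DataLoader's default drop_last=False.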
Overview of the reading pipeline
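1. Write a custom dataset class; this is the lowest level. Subclass torch.utils.data.Dataset and override at least three methods: __init__, __getitem__ and __len__.
   The class is mainly responsible for reading images from storage, but the images usually need various transforms, such as rescaling. So pass the transform into __init__; in __getitem__, first load the image and then call img_transformed = self.transform(img), where self.transform is the argument stored by __init__.
2. Pass a torchvision.transforms.Compose object as that argument to the custom dataset class.
3. Wrap the dataset from step 2 in torch.utils.data.DataLoader to load it with multiple worker processes.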
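The custom-dataset recipe above (store the transform in __init__, load and then transform in __getitem__, report the size in __len__) can be sketched without PyTorch. ListDataset and its in-memory "images" are hypothetical; in real code the class would subclass torch.utils.data.Dataset and _load would open a file, for example with PIL.

```python
class ListDataset:
    """Toy map-style dataset; a real one would subclass torch.utils.data.Dataset."""

    def __init__(self, items, transform=None):
        self.items = items          # stand-in for a list of image file paths
        self.transform = transform  # any callable, e.g. a transforms.Compose object

    def _load(self, index):
        # Hypothetical loader; real code would call Image.open(path) here.
        return self.items[index]

    def __getitem__(self, index):
        img = self._load(index)
        if self.transform is not None:
            img = self.transform(img)  # apply the transform passed to __init__
        return img

    def __len__(self):
        return len(self.items)


ds = ListDataset([1, 2, 3], transform=lambda x: x * 10)
print(len(ds), ds[0], ds[2])  # 3 10 30
```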
Using torch.utils.data.Dataset for "all images in one folder"
Take the official super_resolution example. First, in main:
train_set = get_training_set(opt.upscale_factor)
test_set = get_test_set(opt.upscale_factor)
training_data_loader = DataLoader(dataset=train_set, num_workers=opt.threads, batch_size=opt.batchSize, shuffle=True)
testing_data_loader = DataLoader(dataset=test_set, num_workers=opt.threads, batch_size=opt.testBatchSize, shuffle=False)
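Next, trace get_training_set into data.py. This script mainly downloads and extracts the data, and also defines the transforms and builds the datasets: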
from os.path import exists, join, basename
from os import makedirs, remove
from urllib.request import urlopen  # the original used six.moves.urllib (Python 2/3 shim)
import tarfile
# Scale was later renamed Resize in torchvision, and Scale was removed
from torchvision.transforms import Compose, CenterCrop, ToTensor, Resize

from dataset import DatasetFromFolder


def download_bsd300(dest="dataset"):
    output_image_dir = join(dest, "BSDS300/images")

    if not exists(output_image_dir):
        makedirs(dest, exist_ok=True)
        url = "http://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/BSDS300-images.tgz"
        print("downloading url ", url)

        data = urlopen(url)

        file_path = join(dest, basename(url))
        with open(file_path, 'wb') as f:
            f.write(data.read())

        print("Extracting data")
        with tarfile.open(file_path) as tar:
            for item in tar:
                tar.extract(item, dest)

        remove(file_path)

    return output_image_dir


def calculate_valid_crop_size(crop_size, upscale_factor):
    return crop_size - (crop_size % upscale_factor)


def input_transform(crop_size, upscale_factor):
    return Compose([
        CenterCrop(crop_size),
        Resize(crop_size // upscale_factor),  # Scale(...) in the original
        ToTensor(),
    ])


def target_transform(crop_size):
    return Compose([
        CenterCrop(crop_size),
        ToTensor(),
    ])


# From here on, the custom dataset class is called!
def get_training_set(upscale_factor):
    root_dir = download_bsd300()
    train_dir = join(root_dir, "train")
    crop_size = calculate_valid_crop_size(256, upscale_factor)

    # The custom dataset class takes transforms as arguments: input_transform(...)
    # is called here, and the Compose object it returns is passed into the class.
    return DatasetFromFolder(train_dir,
                             input_transform=input_transform(crop_size, upscale_factor),
                             target_transform=target_transform(crop_size))


def get_test_set(upscale_factor):
    root_dir = download_bsd300()
    test_dir = join(root_dir, "test")
    crop_size = calculate_valid_crop_size(256, upscale_factor)

    return DatasetFromFolder(test_dir,
                             input_transform=input_transform(crop_size, upscale_factor),
                             target_transform=target_transform(crop_size))
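calculate_valid_crop_size rounds the crop size down to a multiple of upscale_factor, so that the low-resolution input of side crop_size // upscale_factor scales back up to exactly crop_size. A quick check of the sizes this gives for a 256-pixel crop (the helper is copied from data.py; the loop is only for illustration):

```python
def calculate_valid_crop_size(crop_size, upscale_factor):
    # Largest size <= crop_size that is divisible by upscale_factor.
    return crop_size - (crop_size % upscale_factor)


for factor in (2, 3, 4):
    crop = calculate_valid_crop_size(256, factor)
    low_res = crop // factor  # side length after the input transform's resize step
    print(factor, crop, low_res, low_res * factor == crop)
# 2 256 128 True
# 3 255 85 True
# 4 256 64 True
```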
Then open dataset.py, which is where the custom dataset class is defined.
import torch.utils.data as data

from os import listdir
from os.path import join
from PIL import Image


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in [".png", ".jpg", ".jpeg"])


def load_img(filepath):
    # Keep only the luminance (Y) channel of the YCbCr image
    img = Image.open(filepath).convert('YCbCr')
    y, _, _ = img.split()
    return y


class DatasetFromFolder(data.Dataset):
    def __init__(self, image_dir, input_transform=None, target_transform=None):
        super(DatasetFromFolder, self).__init__()
        self.image_filenames = [join(image_dir, x) for x in listdir(image_dir) if is_image_file(x)]

        self.input_transform = input_transform
        self.target_transform = target_transform

    # __getitem__ loads the image and applies the transforms passed in to it:
    #     input = self.input_transform(input)
    # self.input_transform is the instance that was passed in (a Compose object).
    # Its class defines __call__, so the instance can be called like a function,
    # i.e. instance(argument). The previous blog post covered this.
    def __getitem__(self, index):
        input = load_img(self.image_filenames[index])
        target = input.copy()
        if self.input_transform:
            input = self.input_transform(input)
        if self.target_transform:
            target = self.target_transform(target)

        return input, target

    def __len__(self):
        return len(self.image_filenames)
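The call self.input_transform(input) works because Compose instances define __call__. A simplified re-implementation shows the idea; MiniCompose is a hypothetical stand-in that only chains callables and omits everything else torchvision's Compose does.

```python
class MiniCompose:
    """Chain callables; a simplified stand-in for torchvision.transforms.Compose."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        # Because the class defines __call__, an instance can be called like a function.
        for t in self.transforms:
            img = t(img)
        return img


pipeline = MiniCompose([lambda x: x + 1, lambda x: x * 2])
print(pipeline(3))  # 8
```

The transforms run in list order, so (3 + 1) * 2 = 8; swapping the two lambdas would give 7.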
Now look inside torchvision.datasets.MNIST (code from an older torchvision release):
import os

import torch
import torch.utils.data as data
from PIL import Image


class MNIST(data.Dataset):
    """`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.

    Args:
        root (string): Root directory of dataset where ``processed/training.pt``
            and ``processed/test.pt`` exist.
        train (bool, optional): If True, creates dataset from ``training.pt``,
            otherwise from ``test.pt``.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """
    urls = [
        'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz',
        'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz',
        'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz',
        'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz',
    ]
    raw_folder = 'raw'
    processed_folder = 'processed'
    training_file = 'training.pt'
    test_file = 'test.pt'

    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.train = train  # training set or test set

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found.' +
                               ' You can use download=True to download it')

        # Use the expanded self.root (the excerpt used the raw root argument here)
        if self.train:
            self.train_data, self.train_labels = torch.load(
                os.path.join(self.root, self.processed_folder, self.training_file))
        else:
            self.test_data, self.test_labels = torch.load(
                os.path.join(self.root, self.processed_folder, self.test_file))

    # As you can see, this also applies transforms via img = self.transform(img).
    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        if self.train:
            img, target = self.train_data[index], self.train_labels[index]
        else:
            img, target = self.test_data[index], self.test_labels[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img.numpy(), mode='L')

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        if self.train:
            return len(self.train_data)
        else:
            return len(self.test_data)
…
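target_transform is applied to the label in the same way transform is applied to the image. As a hypothetical example, a one-hot encoder for MNIST's 10 classes could be passed as target_transform; one_hot is an illustrative helper, not part of torchvision.

```python
def one_hot(label, num_classes=10):
    """Hypothetical target_transform: class index -> one-hot list."""
    vec = [0] * num_classes
    vec[int(label)] = 1  # int() also accepts a one-element tensor label
    return vec


print(one_hot(3))  # [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]
```

It would be used as MNIST(root, target_transform=one_hot); in practice, integer labels are usually left as they are, because nn.CrossEntropyLoss expects class indices.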
Using torchvision.datasets.ImageFolder for "one class of images per folder"
For example, the ImageNet example code:
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
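transforms.Normalize(mean, std) computes (x - mean[c]) / std[c] for each channel c of a tensor that ToTensor has already scaled to [0, 1]. The per-pixel arithmetic as a plain-Python sketch (normalize_pixel is a hypothetical helper, and the real transform works on whole tensors):

```python
MEAN = [0.485, 0.456, 0.406]  # ImageNet statistics from the code above
STD = [0.229, 0.224, 0.225]


def normalize_pixel(rgb, mean=MEAN, std=STD):
    # Same per-channel formula that transforms.Normalize applies to every pixel.
    return [(x - m) / s for x, m, s in zip(rgb, mean, std)]


print(normalize_pixel([0.485, 0.456, 0.406]))  # [0.0, 0.0, 0.0]
print([round(v, 3) for v in normalize_pixel([1.0, 1.0, 1.0])])
```

A pixel equal to the dataset mean maps to zero, and values are measured in units of the per-channel standard deviation.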
Steps 1 and 2: read the dataset with ImageFolder, passing the transforms in directly
train_dataset = datasets.ImageFolder(
    traindir,
    transforms.Compose([
        transforms.RandomResizedCrop(224),  # RandomSizedCrop in older torchvision
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ]))
…
Step 3: load batches with DataLoader using multiple worker processes
# shuffle must be False when a sampler is given; the two options are mutually exclusive
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
    num_workers=args.workers, pin_memory=True, sampler=train_sampler)

val_loader = torch.utils.data.DataLoader(
    datasets.ImageFolder(valdir, transforms.Compose([
        transforms.Resize(256),  # Scale(256) in older torchvision
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])),
    batch_size=args.batch_size, shuffle=False,
    num_workers=args.workers, pin_memory=True)
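In the validation pipeline, Resize(256) with an integer size scales the shorter side to 256 while keeping the aspect ratio, and CenterCrop(224) then cuts out the central 224x224 region. The sketch below computes that geometry for a hypothetical 500x375 image. The two helpers are illustrative, and their rounding is an assumption modeled on torchvision's behavior, so treat exact pixel offsets as approximate.

```python
def resize_shorter_side(width, height, size):
    """Output (w, h) when the shorter side is scaled to `size`, aspect ratio kept."""
    if width <= height:
        return size, int(size * height / width)
    return int(size * width / height), size


def center_crop_box(width, height, crop):
    """(left, top, right, bottom) of a centered crop x crop window."""
    left = int(round((width - crop) / 2.0))
    top = int(round((height - crop) / 2.0))
    return left, top, left + crop, top + crop


w, h = resize_shorter_side(500, 375, 256)
print(w, h)  # 341 256
print(center_crop_box(w, h, 224))
```

Resizing first and cropping second guarantees the crop always fits, whatever the original aspect ratio.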