步骤:
- 图片分类存储在不同文件夹下
- 写一个继承自 torch.utils.data.Dataset 的类,并重写 __len__() 和 __getitem__() 方法
- 打标签
- 写一个方法:若 csv 文件不存在,则先把图片路径与标签以“,”分隔写入 csv 文件;然后从该文件中把数据加载出来
- __getitem__() 方法把 csv 中的路径对应的图片读出来,进行转换后返回,便于用 torch.utils.data.DataLoader 加载
import torch
import os, glob
import random, csv
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
class ADS_B(Dataset):
    """Image-folder dataset with a CSV-cached, shuffled train/val/test split.

    Every sub-directory of ``root`` is treated as one class; a class label is
    the index of its directory name in sorted order.  The shuffled list of
    ``(image path, label)`` rows is written once to ``<root>/imagepaths.csv``
    and re-read on later constructions, so every mode sees the same ordering.

    The first 60 % of the rows form the train split, the next 20 % the val
    split and the remaining rows the split used for any other ``mode``.

    Args:
        root: Directory containing one sub-directory per class.
        resize: Side length in pixels; images are resized to
            ``(resize, resize)``.
        mode: ``'train'`` or ``'val'``; any other value selects the last
            20 % of the rows.
    """

    # Per-channel statistics shared by __getitem__ (Normalize) and denormalize.
    MEAN = [0.485, 0.456, 0.406]
    STD = [0.229, 0.224, 0.225]

    def __init__(self, root, resize, mode):
        super().__init__()
        self.root = root
        self.resize = resize

        # Map each class directory name to a consecutive integer label.
        self.name2label = {}
        for name in sorted(os.listdir(root)):
            if os.path.isdir(os.path.join(root, name)):
                self.name2label[name] = len(self.name2label)

        imagepaths, labels = self.load_csv('imagepaths.csv')

        # Compute both boundaries once from the FULL list so that paths and
        # labels are always cut at exactly the same positions.
        total = len(imagepaths)
        train_end = int(0.6 * total)
        val_end = int(0.8 * total)
        if mode == 'train':
            selected = slice(None, train_end)
        elif mode == 'val':
            selected = slice(train_end, val_end)
        else:
            selected = slice(val_end, None)

        self.imagepaths = imagepaths[selected]
        self.labels = labels[selected]

    def load_csv(self, filename):
        """Return ``(imagepaths, labels)`` read from ``<root>/<filename>``.

        If the CSV file does not exist yet, it is created first: all
        ``*.png`` and ``*.jpg`` files inside every class directory are
        collected, shuffled and written as ``path,label`` rows.

        Args:
            filename: Name of the CSV file inside ``self.root``.

        Returns:
            A pair of equally long lists: image paths (str) and labels (int).

        Raises:
            ValueError: If a non-empty row does not hold exactly two fields
                or its label is not an integer.
        """
        csv_path = os.path.join(self.root, filename)

        if not os.path.exists(csv_path):
            # Pair every path with its label while globbing, so the label
            # never has to be recovered by parsing the path afterwards.
            rows = []
            for name, label in self.name2label.items():
                for pattern in ('*.png', '*.jpg'):
                    found = glob.glob(os.path.join(self.root, name, pattern))
                    rows.extend((path, label) for path in found)
            random.shuffle(rows)
            with open(csv_path, mode='w', newline='') as f:
                csv.writer(f).writerows(rows)
            print('write into csv file:', filename)

        imagepaths, labels = [], []
        with open(csv_path, newline='') as f:
            for row in csv.reader(f):
                if not row:
                    # Tolerate blank lines (e.g. a hand-edited file).
                    continue
                imagepath, label = row
                imagepaths.append(imagepath)
                labels.append(int(label))
        return imagepaths, labels

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.imagepaths)

    def denormalize(self, x_hat):
        """Undo the ``Normalize`` step applied in ``__getitem__``.

        Args:
            x_hat: Normalized tensor whose channel axis (size 3) is followed
                by exactly two more axes, e.g. ``(3, H, W)`` or
                ``(B, 3, H, W)``.

        Returns:
            ``x_hat * std + mean``, broadcast per channel.
        """
        # Shape (3, 1, 1) broadcasts over the two trailing spatial axes.
        # Create the statistics on x_hat's device so non-CPU tensors work.
        mean = torch.tensor(self.MEAN, device=x_hat.device).unsqueeze(1).unsqueeze(1)
        std = torch.tensor(self.STD, device=x_hat.device).unsqueeze(1).unsqueeze(1)
        return x_hat * std + mean

    def __getitem__(self, idx):
        """Load, resize and normalize the image at ``idx``.

        Args:
            idx: Position inside the selected split.

        Returns:
            ``(image, label)`` where ``image`` is the normalized tensor
            produced by the transform pipeline and ``label`` is a 0-dim
            tensor holding the class index.
        """
        imagepath, label = self.imagepaths[idx], self.labels[idx]
        tf = transforms.Compose([
            transforms.Resize((self.resize, self.resize)),
            transforms.ToTensor(),
            transforms.Normalize(mean=self.MEAN, std=self.STD),
        ])
        # The context manager guarantees the underlying file is closed.
        with Image.open(imagepath) as img:
            image = tf(img.convert('RGB'))
        return image, torch.tensor(label)
def main():
    """Visual smoke test: show one sample and then every batch in Visdom.

    Builds the ``'train'`` split of the ``'ADS-B'`` folder at 224x224,
    prints basic shape information and pushes the denormalized images to a
    Visdom server.
    """
    import visdom
    import time

    viz = visdom.Visdom()
    dataset = ADS_B('ADS-B', 224, 'train')

    # Inspect a single sample first.
    sample_iter = iter(dataset)
    image, target = next(sample_iter)
    print('sample:', image.shape, target.shape, target)
    restored = dataset.denormalize(image)
    viz.image(restored, win='sample_x', opts={'title': 'sample_x'})

    # Then walk through the shuffled dataset batch by batch.
    loader = DataLoader(dataset, batch_size=32, shuffle=True)
    print('len of loader:', len(loader))
    for images, targets in loader:
        restored_batch = dataset.denormalize(images)
        viz.images(restored_batch, nrow=8, win='batch', opts={'title': 'batch'})
        viz.text(str(targets.numpy()), win='label', opts={'title': 'batch-y'})
        # Pause between batches (presumably so each one can be viewed).
        time.sleep(10)
# Run the Visdom demo only when this file is executed as a script.
if __name__ == '__main__':
    main()