# Requires a running Visdom server; start one first with: python -m visdom.server
from torch.utils.data import Dataset, DataLoader
import torch
import os, glob
import random, csv
# 图片读取工具
from torchvision import transforms
from PIL import Image
class NumberDataset1(Dataset):
    """Image-classification dataset read from a folder of per-class sub-directories.

    Every immediate sub-directory of ``root`` is treated as one class. Class
    names are sorted before labels are assigned, so the name -> label mapping
    is stable between runs. The shuffled ``(path, label)`` index is cached in
    ``<root>/images1.csv`` and then cut 60% / 20% / 20% into the train / val /
    test splits.
    """

    # Per-channel statistics applied by Normalize and undone by denormalize()
    # (the values commonly used with ImageNet-pretrained models).
    _MEAN = [0.485, 0.456, 0.406]
    _STD = [0.229, 0.224, 0.225]

    # Augmentation pipeline, built on first __getitem__ call and then reused.
    _transform = None

    def __init__(self, root, resize, mode):
        """Index the images under ``root`` and keep the split selected by ``mode``.

        Args:
            root: Dataset directory containing one sub-directory per class.
            resize: Side length, in pixels, of the square images returned by
                ``__getitem__``. It is read when the first sample is loaded.
            mode: ``'train'`` keeps the first 60% of the index, ``'val'`` the
                next 20%; any other value keeps the final 20% (test split).
        """
        super().__init__()
        self.root = root
        self.resize = resize

        # Sorted so that labels do not depend on directory listing order.
        self.name2label = {}
        for name in sorted(os.listdir(root)):
            if not os.path.isdir(os.path.join(root, name)):
                continue
            # The next free index becomes the label of this class.
            self.name2label[name] = len(self.name2label)
        print(self.name2label)

        self.images, self.labels = self.load_csv('images1.csv')

        # Compute both cut points once, from the full index, so the image and
        # label lists are always sliced with identical bounds.
        total = len(self.images)
        train_end = int(0.6 * total)
        val_end = int(0.8 * total)
        if mode == 'train':
            keep = slice(None, train_end)
        elif mode == 'val':
            keep = slice(train_end, val_end)
        else:
            keep = slice(val_end, None)
        self.images = self.images[keep]
        self.labels = self.labels[keep]

    def load_csv(self, filename):
        """Return ``(images, labels)`` from the cached index, creating it if missing.

        Args:
            filename: Name of the CSV file inside ``self.root``.

        Returns:
            A list of image paths and a parallel list of integer labels.
        """
        csv_path = os.path.join(self.root, filename)

        if not os.path.exists(csv_path):
            images = []
            for name in self.name2label:
                for pattern in ('*.png', '*.jpg', '*.jpeg'):
                    images += glob.glob(os.path.join(self.root, name, pattern))
            print(len(images), images)

            # Shuffle once; the order is then frozen in the CSV so that every
            # split sees the same permutation on later runs.
            random.shuffle(images)
            with open(csv_path, mode='w', newline='') as f:
                writer = csv.writer(f)
                for img in images:
                    # The parent directory name is the class name.
                    name = os.path.basename(os.path.dirname(img))
                    writer.writerow([img, self.name2label[name]])

        images, labels = [], []
        with open(csv_path, newline='') as f:
            for row in csv.reader(f):
                if not row:
                    # Tolerate blank lines (e.g. left by hand-editing the file).
                    continue
                img, label = row
                images.append(img)
                labels.append(int(label))
        return images, labels

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.images)

    def _build_transform(self):
        """Build the pipeline: enlarge to 1.25x, rotate, centre-crop, normalize."""
        enlarged = int(self.resize * 1.25)
        return transforms.Compose([
            transforms.Resize((enlarged, enlarged)),
            transforms.RandomRotation(15),
            # Cropping back to `resize` trims the borders exposed by rotation.
            transforms.CenterCrop(self.resize),
            transforms.ToTensor(),
            transforms.Normalize(mean=self._MEAN, std=self._STD),
        ])

    def __getitem__(self, item):
        """Load, augment and normalize one sample.

        Args:
            item: Index into the selected split.

        Returns:
            ``(img, label)``: the transformed image tensor and the label as a
            0-dim tensor.
        """
        path, label = self.images[item], self.labels[item]

        if self._transform is None:
            self._transform = self._build_transform()

        # Decode inside a context manager so the file handle is released
        # as soon as the pixel data has been converted.
        with Image.open(path) as handle:
            img = handle.convert('RGB')

        return self._transform(img), torch.tensor(label)

    def denormalize(self, x_hat):
        """Invert the Normalize step: ``x = x_hat * std + mean``.

        Args:
            x_hat: Normalized image tensor shaped ``[3, h, w]`` or
                ``[b, 3, h, w]``.

        Returns:
            The de-normalized tensor, same shape as ``x_hat``.
        """
        # Create the statistics on x_hat's device so GPU batches work too;
        # shape [3, 1, 1] broadcasts over height/width (and a batch axis).
        mean = torch.tensor(self._MEAN, device=x_hat.device).view(3, 1, 1)
        std = torch.tensor(self._STD, device=x_hat.device).view(3, 1, 1)
        return x_hat * std + mean
def main():
    """Placeholder entry point: currently a no-op that returns ``None``."""
    return None
if __name__ == '__main__':
    # Demo: push batches of images from the 'pokemon' folder to a Visdom server.
    import time

    import torchvision
    import visdom

    viewer = visdom.Visdom()

    # Option 1: let torchvision's ImageFolder derive labels from sub-folders.
    preprocess = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
    ])
    folder_ds = torchvision.datasets.ImageFolder(root='pokemon', transform=preprocess)
    # num_workers=4 -> four loader worker processes.
    batches = DataLoader(folder_ds, batch_size=32, shuffle=True, num_workers=4)

    for images, targets in batches:
        # nrow=8 lays the grid out with eight images per row.
        viewer.images(images, nrow=8, win='batch', opts=dict(title='batch'))
        viewer.text(str(targets.numpy()), win='label', opts=dict(title='batch-y'))

        # Pause for 10 seconds after each batch so it can be inspected.
        time.sleep(10)

    # Option 2 (disabled): use the custom NumberDataset1 instead. Its samples
    # are normalized, so they must go through denormalize() before display,
    # otherwise the image is not rendered correctly.
    #
    # number_ds = NumberDataset1('pokemon', 64, 'train')
    # sample_x, sample_y = next(iter(number_ds))
    # print(sample_x.shape, sample_y.shape, sample_y)
    # viewer.image(number_ds.denormalize(sample_x), win='sample_x', opts=dict(title='sample_x'))
    # batches = DataLoader(number_ds, batch_size=32, shuffle=True)
    # for images, targets in batches:
    #     viewer.images(number_ds.denormalize(images), nrow=8, win='batch', opts=dict(title='batch'))
    #     viewer.text(str(targets.numpy()), win='label', opts=dict(title='batch-y'))
    #
    #     time.sleep(10)
# Result screenshot: