import torch
import os
import glob
import random
import csv
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
#加载自己的数据集
class Pokemon(Dataset):
    """Image-classification dataset backed by a folder-per-class directory tree.

    Expected layout under ``root``::

        root/
            <class_a>/  *.png | *.jpg | *.jpeg
            <class_b>/  ...

    Every sub-directory of ``root`` is one class. Labels are assigned in
    sorted directory-name order (0, 1, 2, ...). The (path, label) listing is
    shuffled once and cached in ``root/images.csv``. As a result, the 'train',
    'val' and 'test' instances all slice the *same* ordering (60% / 20% / 20%)
    and never overlap.

    NOTE: the CSV cache is not invalidated automatically. Delete
    ``root/images.csv`` after adding or removing images or class folders;
    otherwise the stale listing, including its labels, is reused.
    """

    # Per-channel (R, G, B) statistics shared by Normalize in __getitem__ and
    # by denomalize (the commonly used ImageNet values).
    _MEAN = (0.485, 0.456, 0.406)
    _STD = (0.229, 0.224, 0.225)

    def __init__(self, root, resize, mode):
        """Build the index for one split.

        Args:
            root: directory containing one sub-directory per class.
            resize: images are resized to ``(resize, resize)`` pixels.
            mode: 'train' (first 60%), 'val' (next 20%) or 'test' (last 20%)
                of the cached ordering. Any other value keeps every sample.
        """
        super().__init__()
        self.root = root
        self.resize = resize

        # Sorting makes the name -> label mapping independent of the
        # file system's listing order. Plain files (e.g. images.csv) are skipped.
        self.name2label = {}
        for name in sorted(os.listdir(root)):
            if os.path.isdir(os.path.join(root, name)):
                self.name2label[name] = len(self.name2label)
        print(self.name2label)

        # Load exactly once. Each load used to reshuffle, which made the
        # splits of different instances overlap.
        self.images, self.labels = self.load_csv('images.csv')

        n = len(self.images)
        if mode == 'train':
            start, end = 0, int(0.6 * n)
        elif mode == 'val':
            start, end = int(0.6 * n), int(0.8 * n)
        elif mode == 'test':
            start, end = int(0.8 * n), n
        else:
            # Unrecognised modes keep the whole dataset (original behaviour).
            start, end = 0, n
        self.images = self.images[start:end]
        self.labels = self.labels[start:end]

    def load_csv(self, filename):
        """Return ``(images, labels)`` as listed in ``root/filename``.

        On first use the file is created: every *.png / *.jpg / *.jpeg file in
        the class folders is collected, shuffled with the global ``random``
        module, and written as ``path,label`` rows. Later calls (and later
        instances) reuse the existing file instead of reshuffling. This is what
        keeps the train/val/test splits disjoint.
        """
        csv_path = os.path.join(self.root, filename)
        if not os.path.exists(csv_path):
            samples = self._scan_images()
            random.shuffle(samples)
            with open(csv_path, mode='w', newline='') as f:
                csv.writer(f).writerows(samples)

        images, labels = [], []
        # The csv module expects newline='' for reading as well as writing.
        with open(csv_path, newline='') as f:
            for img, label in csv.reader(f):
                images.append(img)
                labels.append(int(label))
        return images, labels

    def _scan_images(self):
        """List ``[path, label]`` pairs for every image file in the class folders."""
        samples = []
        for name, label in self.name2label.items():
            # Escape the directory part so folder names containing glob
            # metacharacters ('[', '*', '?') are matched literally.
            class_dir = glob.escape(os.path.join(self.root, name))
            for pattern in ('*.png', '*.jpg', '*.jpeg'):
                for path in glob.glob(os.path.join(class_dir, pattern)):
                    samples.append([path, label])
        return samples

    def __len__(self):
        """Number of samples in this split."""
        return len(self.images)

    def denomalize(self, x_hat):
        """Undo the Normalize step: return ``x_hat * std + mean`` per channel.

        The statistics are shaped (3, 1, 1), so they broadcast over a single
        (3, H, W) image or a (B, 3, H, W) batch. They are created on
        ``x_hat``'s device, so GPU tensors work too.
        """
        mean = torch.tensor(self._MEAN, device=x_hat.device).view(-1, 1, 1)
        std = torch.tensor(self._STD, device=x_hat.device).view(-1, 1, 1)
        return x_hat * std + mean

    # Correctly spelled alias; the original name is kept for existing callers.
    denormalize = denomalize

    def __getitem__(self, idx):
        """Return ``(image_tensor, label_tensor)`` for sample *idx*.

        The image is loaded as RGB, resized to ``(resize, resize)``, scaled to
        [0, 1] by ToTensor and then standardised with ``_MEAN`` / ``_STD``.
        Use ``denomalize`` to map it back for display.
        """
        path, label = self.images[idx], self.labels[idx]
        # Context manager releases the file handle deterministically.
        # convert('RGB') normalises grayscale/RGBA/palette images to 3 channels.
        with Image.open(path) as im:
            img = im.convert('RGB')
        tf = transforms.Compose([
            transforms.Resize((self.resize, self.resize)),
            transforms.ToTensor(),
            transforms.Normalize(mean=self._MEAN, std=self._STD),
        ])
        return tf(img), torch.tensor(label)
def main(root='D://face recogniton', resize=224, mode='train', batch_size=32, delay=5):
    """Visual smoke test: stream batches of the dataset into a Visdom server.

    Requires a running Visdom server (start it with ``python -m visdom.server``).
    For each batch, the de-normalized images go to the ``batch`` window and the
    label tensor goes to the ``batch_y`` window. The loop then sleeps *delay*
    seconds.

    Args:
        root: dataset root passed to ``Pokemon``.
        resize: side length images are resized to.
        mode: split to show ('train', 'val' or 'test').
        batch_size: images per batch / per Visdom update.
        delay: seconds to wait between batches.
    """
    # Local imports keep the module importable without visdom installed.
    import time
    import visdom

    viz = visdom.Visdom()
    db = Pokemon(root, resize, mode)
    loader = DataLoader(db, batch_size=batch_size, shuffle=True)
    for x, y in loader:
        # Undo Normalize so Visdom renders the images with their original colours.
        viz.images(db.denomalize(x), nrow=8, win='batch', opts=dict(title='batch'))
        viz.text(str(y.numpy()), win='batch_y', opts=dict(title='batch_y'))
        time.sleep(delay)
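# 最后通过 main 函数试运行;运行前需先在终端输入 python -m visdom.server 启动 visdom 服务
# (Finally, try it out via main(); first start the Visdom server in a terminal with: python -m visdom.server)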
# Run the visual smoke test only when executed as a script, not on import.
if __name__ == '__main__':
    main()