目录
数据集类的构建
本实例使用的是UTKFace数据集,包含了两万多张不同种族的不同年龄的人脸图片
torch.utils.data.Dataset
是一个抽象类, 自定义的Dataset需要继承它并且实现两个成员方法:
- getitem()
- len()
第一个最为重要,即每次怎么读数据;
第二个比较简单, 就是返回整个数据集的长度。
implementation
导入包
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import os
import cv2 as cv
_init_()
torchvision.transforms.Normalize(mean, std) 用法
torchvision.transforms.ToTensor() 用法
os.listdir() 用法
.split() 用法
os.path() 模块
class AgeGenderDataset(Dataset):
def __init__(self, root_dir):
# Normalize: image => [-1, 1] (利于更好的训练)
# ToTensor() => Converts a PIL Image or numpy.ndarray (H x W x C) in the range [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
self.transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5],
std=[0.5, 0.5, 0.5]),
transforms.Resize((64, 64))
])
img_files = os.listdir(root_dir) #存放的是所有图片的文件名
# ag