图片数据封装的时候往往需要用到 transform：这个方法一方面将图片数据 reshape 成 [channel, h, w]，另一方面将图片 array 转化为 tensor。
class CatsAndDogsDataset(Dataset):
    """Cats-vs-dogs image dataset driven by a CSV annotation file.

    Each CSV row holds an image filename (column 0) and an integer class
    label (column 1); the image files live under ``root_dir``.
    """

    def __init__(self, csv_file, root_dir, transform=None):
        self.annotations = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        # One sample per annotation row.
        return len(self.annotations)

    def __getitem__(self, index):
        row = self.annotations.iloc[index]
        img_path = os.path.join(self.root_dir, row.iloc[0])
        # io.imread returns an ndarray shaped [h, w, channel].
        image = io.imread(img_path)
        y_label = torch.tensor(int(row.iloc[1]))
        if self.transform:
            # e.g. transforms.ToTensor(): ndarray -> tensor [channel, h, w]
            image = self.transform(image)
        return image, y_label
上面这段代码在实例化数据集的时候，需要传入 transform 方法：
# Build the dataset; transforms.ToTensor() converts the [h, w, channel]
# ndarray returned by __getitem__ into a [channel, h, w] float tensor.
dataset = CatsAndDogsDataset(
    transform=transforms.ToTensor(),
    csv_file='./custom_dataset/cats_dogs.csv',
    root_dir='./custom_dataset/cats_dogs_resized',
)
需要先导入：import torchvision.transforms as transforms。
其中 image = io.imread(img_path) 读取的 image 格式是 array，shape 为 [h, w, channel]；但是网络训练时一般要求 [channel, h, w]，所以 image = self.transform(image) 将其转化为 tensor，shape 为 [channel, h, w] 格式。