1.定义自己的数据集
#1.处理数据,定义自己的数据集
import torch as t
from torch.utils import data
import os
from PIL import Image
import numpy as np
class DogCat(data.Dataset):
def __init__(self,root):
imgs=os.listdir(root)
#所有图片的绝对路径
#这里不实际加载图片,只是指定路径
#当调用__getitem__时才会真正读图片
self.imgs=[os.path.join(root,img) for img in imgs]
def __getitem__(self, index):
img_path=self.imgs[index]
# dog ->1,cat ->0
label=1 if 'dog' in img_path.split('/')[-1] else 0
pil_img=Image.open(img_path)
array=np.asarray(pil_img)
data=t.from_numpy(array)
return data,label
def __len__(self):
return len(self.imgs)
dataset=DogCat('C:\\Users\\29282\\Desktop\\代码\\kaggle\\train\\train')
img,label=dataset[0]
for img,label in dataset:
print(img.size(),img.float().mean(),label)
利用torchvision工具处理图像
import os
from PIL import Image
from torch.utils import data
import numpy as np
from torchvision import transforms as T
transform=T.Compose([
T.Resize(224),#缩放图片,保持长宽比不变,最短边为224像素
T.CenterCrop(224),#从图片中间切出224*224的图片
T.ToTensor(),#将图片转成Tensor,归一化至[0,1]
T.Normalize(mean=[.5,.5,.5],std=[.5,.5,.5])#标准化至[-1,1]
])
class DogCat(data.Dataset):
def __in