transform_test = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])
读取数据
====
将数据集解压后放到data文件夹下面,如图:
然后我们在dataset文件夹下面新建 __init__.py和dataset.py,在dataset.py文件夹写入下面的代码:
说一下代码的核心逻辑。
第一步 建立字典,定义类别对应的ID,用数字代替类别。
第二步 在__init__里面编写获取图片路径的方法。测试集只有一层路径直接读取,训练集在train文件夹下面是类别文件夹,先获取到类别,再获取到具体的图片路径。然后使用sklearn中切分数据集的方法,按照7:3的比例切分训练集和验证集。
第三步 在__getitem__方法中定义读取单个图片和类别的方法,由于图像中有位深度32位的,所以我在读取图像的时候做了转换。
coding:utf8
import os
from PIL import Image
from torch.utils import data
from torchvision import transforms as T
from sklearn.model_selection import train_test_split
Labels = {‘Black-grass’: 0, ‘Charlock’: 1, ‘Cleavers’: 2, ‘Common Chickweed’: 3,
‘Common wheat’: 4, ‘Fat Hen’: 5, ‘Loose Silky-bent’: 6, ‘Maize’: 7, ‘Scentless Mayweed’: 8,
‘Shepherds Purse’: 9, ‘Small-flowered Cranesbill’: 10, ‘Sugar beet’: 11}
class SeedlingData (data.Dataset):
def init(self, root, transforms=None, train=True, test=False):
“”"
主要目标: 获取所有图片的地址,并根据训练,验证,测试划分数据
“”"
self.test = test
self.transforms = transforms
if self.test:
imgs = [os.path.join(root, img) for img in os.listdir(root)]
self.imgs = imgs
else:
imgs_labels = [os.path.join(root, img) for img in os.listdir(root)]
imgs = []
for imglable in imgs_labels:
for imgname in os.listdir(imglable):
imgpath = os.path.join(imglable, imgname)
imgs.append(imgpath)
trainval_files, val_files = train_test_split(imgs, test_size=0.3, random_state=42)
if train:
self.imgs = trainval_files
else:
self.imgs = val_files
def getitem(self, index):
“”"
一次返回一张图片的数据
“”"
img_path = self.imgs[index]
img_path=img_path.replace(“\”,‘/’)
if self.test:
label = -1
else:
la