# 第一步:继承 Dataset 并创建自己的 PaddyDataSet——一般是重写 __init__() 方法和 __getitem__() 方法,并补充 __len__() 方法。
# PaddyDataSet
import torch
from torch.utils.data import Dataset
from PIL import Image
import os
import numpy as np
from torch.utils.data import DataLoader
import numpy as np
import torchvision.transforms as transforms
# Mapping from paddy-disease class name to its integer class label.
# Each label is the class's zero-based position in the tuple below, so the
# ordering of the tuple defines the label assignment (0 through 9).
paddy_labels = {
    class_name: class_index
    for class_index, class_name in enumerate(
        (
            'bacterial_leaf_blight',
            'bacterial_leaf_streak',
            'bacterial_panicle_blight',
            'blast',
            'brown_spot',
            'dead_heart',
            'downy_mildew',
            'hispa',
            'normal',
            'tungro',
        )
    )
}
class PaddyDataSet(Dataset):
def __init__(self, data_dir,transform=None):
"""
数据集
"""
self.label_name={'bacterial_leaf_blight':0,'bacterial_leaf_streak':1,'bacterial_panicle_blight':2,'blast':3,'brown_spot':4,