# 读取txt文件中的标签信息 (read the label information from the .txt annotation files)
# 加载图片数据存放的位置 (set the directories where the image data is stored)
import os
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
# Image directories for the two splits; image file names listed in the
# annotation files are joined onto these paths by FlowerDataset.
train_dir, valid_dir = 'train_filelist', 'val_filelist'
# 具体见代码注释 (see the inline code comments for details)
from torch.utils.data import Dataset, DataLoader
class FlowerDataset(Dataset):
    """Image-classification dataset driven by a plain-text annotation file.

    Each non-empty line of ``ann_file`` holds an image file name (relative to
    ``root_dir``) followed by an integer class label, separated by
    whitespace, e.g. ``xxx.jpg 3``.

    Args:
        root_dir: Directory that the file names in ``ann_file`` are relative to.
        ann_file: Path to the annotation text file.
        transform: Optional callable applied to each PIL image (for example a
            pipeline that converts it to a tensor). Any falsy value (``None``,
            the default, or ``False``) disables it.
    """

    def __init__(self, root_dir, ann_file, transform=None):
        self.ann_file = ann_file
        self.root_dir = root_dir
        # Mapping: file name -> 0-d int64 numpy array holding the label.
        self.img_label = self.load_annotations()
        # Full image paths and their labels, index-aligned with each other
        # (both follow the dict's insertion order, i.e. file order).
        self.img = [os.path.join(self.root_dir, name) for name in self.img_label]
        self.label = list(self.img_label.values())
        self.transform = transform  # preprocessing / augmentation pipeline

    def __len__(self):
        """Return the number of annotated images."""
        return len(self.img)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is decoded fully and converted to RGB, then passed through
        ``transform`` if one was given. The label is returned as a 0-d
        ``torch.int64`` tensor.
        """
        # Open through a context manager so the file handle is released here
        # instead of lingering until the lazily-loaded PIL image is collected;
        # convert('RGB') forces the decode and gives every sample 3 channels.
        with open(self.img[idx], 'rb') as f:
            image = Image.open(f).convert('RGB')
        if self.transform:
            image = self.transform(image)
        label = torch.from_numpy(np.array(self.label[idx]))
        return image, label

    def load_annotations(self):
        """Parse ``self.ann_file`` into a ``{file_name: label}`` dict.

        Blank lines are skipped. The label is taken from the last
        whitespace-separated field, so runs of spaces/tabs between the fields
        and file names containing spaces are accepted. A later duplicate of a
        file name overwrites the earlier entry.

        Returns:
            dict mapping file name to a 0-d ``np.int64`` array.

        Raises:
            ValueError: if a non-empty line has no label field, or (raised by
                numpy) if the label is not an integer literal.
        """
        data_infos = {}
        with open(self.ann_file) as f:
            for lineno, raw in enumerate(f, start=1):
                line = raw.strip()
                if not line:
                    continue  # tolerate blank / trailing empty lines
                parts = line.rsplit(maxsplit=1)
                if len(parts) != 2:
                    raise ValueError(
                        f"{self.ann_file}:{lineno}: expected "
                        f"'<file name> <label>', got {line!r}"
                    )
                filename, gt_label = parts
                data_infos[filename] = np.array(gt_label, dtype=np.int64)
        return data_infos
# Build the training / validation datasets from their annotation files.
# NOTE(review): `data_transforms` is defined outside this chunk; it is assumed
# to be a dict of per-split transform pipelines keyed as below — confirm.
train_dataset = FlowerDataset(
    root_dir=train_dir, ann_file='train.txt', transform=data_transforms['train_filelist'])
val_dataset = FlowerDataset(
    root_dir=valid_dir, ann_file='val.txt', transform=data_transforms['val_filelist'])

# Shuffle only the training data; the validation set is iterated in a fixed
# order so evaluation is deterministic from run to run.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)

# Pull a single training batch to sanity-check the images and labels.
dataiter_train = iter(train_loader)
image, label = next(dataiter_train)