1.如何自定义数据集:
- 1.数据和标签的目录结构先搞定
- 2.写好读取数据和标签路径的函数(根据自己数据集情况来写)
- 3.完成单个数据与标签读取函数
我们以用txt文件指定数据路径与标签为实际例子
数据集情况,train_filelist存入训练集图片,val_filelist存入验证集图片,train.txt、val.txt分别存入图片名称及标签
读取txt文件中的路径和标签
首先,从文件中读取图片名称及标签,先暂时存储为字典结构
def load_annotations(ann_file):
data_infos = {}
with open(ann_file) as f:
samples = [x.strip().split(' ') for x in f.readlines()]
for filename, gt_label in samples:
data_infos[filename] = np.array(gt_label, dtype=np.int64)
return data_infos
分别将图片名及标签存入列表中
写入完整图片路径
用dataloader实现
- 1.注意要使用from torch.utils.data import Dataset, DataLoader
- 2.类名定义class FlowerDataset(Dataset),其中FlowerDataset可以改成自己的名字
- 3.def init(self, root_dir, ann_file, transform=None):咱们要根据自己任务重写
- 4.def getitem(self, idx):根据自己任务,返回图像数据和标签数据
其中,getitem中的idx表示,将数据shuffle后,取的索引
from torch.utils.data import Dataset, DataLoader
class FlowerDataset(Dataset):
def __init__(self, root_dir, ann_file, transform=None):
self.ann_file = ann_file
self.root_dir = root_dir
self.img_label = self.load_annotations()
self.img = [os.path.join(self.root_dir,img) for img in list(self.img_label.keys())]
self.label = [label for label in list(self.img_label.values())]
self.transform = transform
def __len__(self):
return len(self.img)
def __getitem__(self, idx):
image = Image.open(self.img[idx])
label = self.label[idx]
if self.transform:
image = self.transform(image)
label = torch.from_numpy(np.array(label))
return image, label
def load_annotations(self):
data_infos = {}
with open(self.ann_file) as f:
samples = [x.strip().split(' ') for x in f.readlines()]
for filename, gt_label in samples:
data_infos[filename] = np.array(gt_label, dtype=np.int64)
return data_infos
数据预处理(transform)
- 预处理的事都在上面的getitem中完成,返回的数据和标签就是建模时模型的输入和损失函数中标签的输入
data_transforms = {
'train':
transforms.Compose([
transforms.Resize(64),
transforms.RandomRotation(45),#随机旋转,-45到45度之间随机选
transforms.CenterCrop(64),#从中心开始裁剪
transforms.RandomHorizontalFlip(p=0.5),#随机水平翻转 选择一个概率概率
transforms.RandomVerticalFlip(p=0.5),#随机垂直翻转
transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),#参数1为亮度,参数2为对比度,参数3为饱和度,参数4为色相
transforms.RandomGrayscale(p=0.025),#概率转换成灰度率,3通道就是R=G=B
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])#均值,标准差
]),
'valid':
transforms.Compose([
transforms.Resize(64),
transforms.CenterCrop(64),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
}
根据写好的class FlowerDataset(Dataset):来实例化dataloader
- 1.构建数据集:分别创建训练和验证用的数据集(如果需要测试集也一样的方法)
- 2.用Torch给的DataLoader方法来实例化(batch啥的自己定,根据你的显存来选合适的)
- 3.打印看看数据的情况
检验一下数据
image, label = iter(train_loader).next()
sample = image[0].squeeze()
sample = sample.permute((1, 2, 0)).numpy() # 将channels从第一个维度变为第三个维度,gpu要加一个.cuda.cpu方法
sample *= [0.229, 0.224, 0.225]
sample += [0.485, 0.456, 0.406]
plt.imshow(sample)
plt.show()
print('Label is: {}'.format(label[0].numpy()))
最后传入模型进行训练就可以了