深度学习|Dataloader 读取txt标签信息

读取txt文件中的标签信息

加载图片数据存放的位置

import os
import matplotlib.pyplot as plt

import numpy as np
import torch
from PIL import Image

train_dir = 'train_filelist'
valid_dir = 'val_filelist'

 具体见代码注释

from torch.utils.data import Dataset, DataLoader


class FlowerDataset(Dataset):
    def __init__(self, root_dir, ann_file, transform=False):
        self.ann_file = ann_file
        self.root_dir = root_dir
        self.img_label = self.load_annotations()
        self.img = [os.path.join(self.root_dir, img) for img in list(self.img_label.keys())]    #将文件名和img结合起来成为路径 存在img
        self.label = [label for label in list(self.img_label.values())]                         #将字典中values值取出放在label中
        self.transform = transform                                                              #数据预处理

    def __len__(self):
        return len(self.img)

    def __getitem__(self, idx):           #idx=index  索引
        image = Image.open(self.img[idx])
        label = self.label[idx]
        if self.transform:
            image = self.transform(image)            #如果需要处理,就先处理  变成tensor形式
        label = torch.from_numpy(np.array(label))     #将lablel转换成张量形式
        return image, label

    def load_annotations(self):
        data_infos = {}
        with open(self.ann_file) as f:
            samples = [x.strip().split(' ') for x in f.readlines()]     #读取标签文件 空格作为分隔符 一条一条读取 放在sample里 此时sample里有 文件名 和标签 list形式
            for filename, gt_label in samples:                          #从list中读取 文件名和标签 存在data_info中 字典形式, 一个key 对应一个value
                data_infos[filename] = np.array(gt_label, dtype=np.int64)
        return data_infos                                               #返回字典结构
train_dataset = FlowerDataset(root_dir=train_dir, ann_file = 'train.txt', transform=data_transforms['train_filelist'])

val_dataset = FlowerDataset(root_dir=valid_dir, ann_file = 'val.txt', transform=data_transforms['val_filelist'])


train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=True)

dataiter_train=iter(train_loader)
image, label = next(dataiter_train)

  • 12
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

孔雀飞

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值