pytorch-迭代器

class FeatDataTrain(Dataset):
    def __init__(self, file_list , batch_size):
        # image_filenames - 音频路径集合
        self.filenames = self.load_files( file_list)
        self.batch_size = batch_size

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        file_name = self.filenames[idx]
        dir = file_name[0]
        x = cPickle.load(open(dir, 'rb'))
        x = x.astype(np.float32)
        y = int(dir.split('/')[-1].split('.')[0].split('_')[0][0])
        y= y - 1  # 若num_classes=3,则y应为[0,1,2]
        return x ,y

    def load_files(self, csv_path):
        filpath = []
        if csv_path != "":
            f = open(csv_path)
            data = csv.reader(f)  #
            for line in data:
                tmp = line
                filpath.append(tmp)
        else:
            print('没有特征文件!')
        return filpath

最后一个Batch长度不足,会导致输出维度发生问题.drop_last (bool, optional): 如果设置为True:这个是对最后的未完成的batch来说的,比如你的batch_size设置为64,而一个epoch只有100个样本,那么训练的时候后面的36个就被扔掉了

通过读取文件夹路径分配文件。

(2)

loader = DataLoader(dataset,
                    num_workers=opt.num_workers, drop_last=True,
                    shuffle=True, batch_size= opt.batch_size)

在模型测试阶段,如果shuffle=False,会出现意想不到的情况。

参考:

(1)FancyKeras-数据的输入(花式) - 知乎

(2)model.fit以及model.fit_generator区别及用法_猫爱吃鱼the的博客-CSDN博客_model.fit_generator参数

(3)DataLoader详解_sereasuesue的博客-CSDN博客_dataloader

(4)PyTorch 1.x 常用知识_Arrow的博客-CSDN博客

(5)pytorch官方教程中文版(一)PyTorch介绍_努力永远不会欺骗人的博客-CSDN博客_pytorch官方教程中文版

(6)PyTorch简介 - PyTorch官方教程中文版

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值